aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-11-19 13:19:07 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-11-19 13:19:07 +0100
commitb5ae42349c4fd88022a63006060d72b03aa83b16 (patch)
tree18fecbdba6437003f574c15eddeb75bbed9980b3
parent5705a5e2efdcbbaffa5da00555b1afb3b4f9d2af (diff)
parent9ebab57c417d4fd19db6cf69f920a3adb1a1e092 (diff)
downloadwireguard-go-b5ae42349c4fd88022a63006060d72b03aa83b16.tar.gz
wireguard-go-b5ae42349c4fd88022a63006060d72b03aa83b16.zip
Merge branch 'source-caching'
-rw-r--r--src/conn.go109
-rw-r--r--src/conn_default.go122
-rw-r--r--src/conn_linux.go525
-rw-r--r--src/cookie.go12
-rw-r--r--src/cookie_test.go7
-rw-r--r--src/daemon_linux.go22
-rw-r--r--src/device.go40
-rw-r--r--src/helper_test.go8
-rw-r--r--src/main.go130
-rw-r--r--src/misc.go8
-rw-r--r--src/noise_test.go8
-rw-r--r--src/peer.go29
-rw-r--r--src/receive.go257
-rw-r--r--src/send.go23
-rwxr-xr-xsrc/tests/netns.sh104
-rw-r--r--src/timers.go29
-rw-r--r--src/tun.go6
-rw-r--r--src/tun_linux.go29
-rw-r--r--src/uapi.go81
-rw-r--r--src/uapi_linux.go117
20 files changed, 1178 insertions, 488 deletions
diff --git a/src/conn.go b/src/conn.go
index 2cf588d..5b40a23 100644
--- a/src/conn.go
+++ b/src/conn.go
@@ -2,10 +2,35 @@ package main
import (
"errors"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
"net"
- "time"
)
+/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
+ */
+type Bind interface {
+ SetMark(value uint32) error
+ ReceiveIPv6(buff []byte) (int, Endpoint, error)
+ ReceiveIPv4(buff []byte) (int, Endpoint, error)
+ Send(buff []byte, end Endpoint) error
+ Close() error
+}
+
+/* An Endpoint maintains the source/destination caching for a peer
+ *
+ * dst : the remote address of a peer ("endpoint" in uapi terminology)
+ * src : the local address from which datagrams originate going to the peer
+ */
+type Endpoint interface {
+ ClearSrc() // clears the source address
+ SrcToString() string // returns the local source address (ip:port)
+ DstToString() string // returns the destination address (ip:port)
+ DstToBytes() []byte // used for mac2 cookie calculations
+ DstIP() net.IP
+ SrcIP() net.IP
+}
+
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
@@ -27,63 +52,83 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
return addr, err
}
-func updateUDPConn(device *Device) error {
+/* Must hold device and net lock
+ */
+func unsafeCloseUDPListener(device *Device) error {
+ var err error
+ netc := &device.net
+ if netc.bind != nil {
+ err = netc.bind.Close()
+ netc.bind = nil
+ }
+ return err
+}
+
+// must inform all listeners
+func UpdateUDPListener(device *Device) error {
+ device.mutex.Lock()
+ defer device.mutex.Unlock()
+
netc := &device.net
netc.mutex.Lock()
defer netc.mutex.Unlock()
- // close existing connection
+ // close existing sockets
- if netc.conn != nil {
- netc.conn.Close()
- netc.conn = nil
-
- // We need for that fd to be closed in all other go routines, which
- // means we have to wait. TODO: find less horrible way of doing this.
- time.Sleep(time.Second / 2)
+ if err := unsafeCloseUDPListener(device); err != nil {
+ return err
}
- // open new connection
+ // assumption: netc.update WaitGroup should be exactly 1
+
+ // open new sockets
if device.tun.isUp.Get() {
- // listen on new address
+ device.log.Debug.Println("UDP bind updating")
- conn, err := net.ListenUDP("udp", netc.addr)
+ // bind to new port
+
+ var err error
+ netc.bind, netc.port, err = CreateBind(netc.port)
if err != nil {
+ netc.bind = nil
return err
}
- // set fwmark
+ // set mark
- err = setMark(netc.conn, netc.fwmark)
+ err = netc.bind.SetMark(netc.fwmark)
if err != nil {
return err
}
- // retrieve port (may have been chosen by kernel)
+ // clear cached source addresses
+
+ for _, peer := range device.peers {
+ peer.mutex.Lock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
+ peer.mutex.Unlock()
+ }
- addr := conn.LocalAddr()
- netc.conn = conn
- netc.addr, _ = net.ResolveUDPAddr(
- addr.Network(),
- addr.String(),
- )
+ // decrease waitgroup to 0
- // notify goroutines
+ go device.RoutineReceiveIncomming(ipv4.Version, netc.bind)
+ go device.RoutineReceiveIncomming(ipv6.Version, netc.bind)
- signalSend(device.signal.newUDPConn)
+ device.log.Debug.Println("UDP bind has been updated")
}
return nil
}
-func closeUDPConn(device *Device) {
- netc := &device.net
- netc.mutex.Lock()
- if netc.conn != nil {
- netc.conn.Close()
- }
- netc.mutex.Unlock()
- signalSend(device.signal.newUDPConn)
+func CloseUDPListener(device *Device) error {
+ device.mutex.Lock()
+ device.net.mutex.Lock()
+ err := unsafeCloseUDPListener(device)
+ device.net.mutex.Unlock()
+ device.mutex.Unlock()
+ return err
}
diff --git a/src/conn_default.go b/src/conn_default.go
index e7c60a8..5b73c90 100644
--- a/src/conn_default.go
+++ b/src/conn_default.go
@@ -6,6 +6,126 @@ import (
"net"
)
-func setMark(conn *net.UDPConn, value uint32) error {
+/* This code is meant to be a temporary solution
+ * on platforms for which the sticky socket / source caching behavior
+ * has not yet been implemented.
+ *
+ * See conn_linux.go for an implementation on the linux platform.
+ */
+
+type NativeBind struct {
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+}
+
+type NativeEndpoint net.UDPAddr
+
+var _ Bind = (*NativeBind)(nil)
+var _ Endpoint = (*NativeEndpoint)(nil)
+
+func CreateEndpoint(s string) (Endpoint, error) {
+ addr, err := parseEndpoint(s)
+ return (*NativeEndpoint)(addr), err
+}
+
+func (_ *NativeEndpoint) ClearSrc() {}
+
+func (e *NativeEndpoint) DstIP() net.IP {
+ return (*net.UDPAddr)(e).IP
+}
+
+func (e *NativeEndpoint) SrcIP() net.IP {
+ return nil // not supported
+}
+
+func (e *NativeEndpoint) DstToBytes() []byte {
+ addr := (*net.UDPAddr)(e)
+ out := addr.IP
+ out = append(out, byte(addr.Port&0xff))
+ out = append(out, byte((addr.Port>>8)&0xff))
+ return out
+}
+
+func (e *NativeEndpoint) DstToString() string {
+ return (*net.UDPAddr)(e).String()
+}
+
+func (e *NativeEndpoint) SrcToString() string {
+ return ""
+}
+
+func listenNet(network string, port int) (*net.UDPConn, int, error) {
+
+ // listen
+
+ conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // retrieve port
+
+ laddr := conn.LocalAddr()
+ uaddr, err := net.ResolveUDPAddr(
+ laddr.Network(),
+ laddr.String(),
+ )
+ if err != nil {
+ return nil, 0, err
+ }
+ return conn, uaddr.Port, nil
+}
+
+func CreateBind(uport uint16) (Bind, uint16, error) {
+ var err error
+ var bind NativeBind
+
+ port := int(uport)
+
+ bind.ipv4, port, err = listenNet("udp4", port)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ bind.ipv6, port, err = listenNet("udp6", port)
+ if err != nil {
+ bind.ipv4.Close()
+ return nil, 0, err
+ }
+
+ return &bind, uint16(port), nil
+}
+
+func (bind *NativeBind) Close() error {
+ err1 := bind.ipv4.Close()
+ err2 := bind.ipv6.Close()
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+ n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
+ return n, (*NativeEndpoint)(endpoint), err
+}
+
+func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+ n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
+ return n, (*NativeEndpoint)(endpoint), err
+}
+
+func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
+ var err error
+ nend := endpoint.(*NativeEndpoint)
+ if nend.IP.To16() != nil {
+ _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
+ } else {
+ _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
+ }
+ return err
+}
+
+func (bind *NativeBind) SetMark(_ uint32) error {
return nil
}
diff --git a/src/conn_linux.go b/src/conn_linux.go
index a349a9e..cdba74f 100644
--- a/src/conn_linux.go
+++ b/src/conn_linux.go
@@ -7,6 +7,7 @@
package main
import (
+ "encoding/binary"
"errors"
"golang.org/x/sys/unix"
"net"
@@ -16,19 +17,229 @@ import (
/* Supports source address caching
*
- * It is important that the endpoint is only updated after the packet content has been authenticated.
- *
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
+ * So this code is remains platform dependent.
*/
-type Endpoint struct {
- // source (selected based on dst type)
- // (could use RawSockaddrAny and unsafe)
- srcIPv6 unix.RawSockaddrInet6
- srcIPv4 unix.RawSockaddrInet4
- srcIf4 int32
-
- dst unix.RawSockaddrAny
+type NativeEndpoint struct {
+ src unix.RawSockaddrInet6
+ dst unix.RawSockaddrInet6
+}
+
+type NativeBind struct {
+ sock4 int
+ sock6 int
+}
+
+var _ Endpoint = (*NativeEndpoint)(nil)
+var _ Bind = NativeBind{}
+
+type IPv4Source struct {
+ src unix.RawSockaddrInet4
+ Ifindex int32
+}
+
+func htons(val uint16) uint16 {
+ var out [unsafe.Sizeof(val)]byte
+ binary.BigEndian.PutUint16(out[:], val)
+ return *((*uint16)(unsafe.Pointer(&out[0])))
+}
+
+func ntohs(val uint16) uint16 {
+ tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
+ return binary.BigEndian.Uint16((*tmp)[:])
+}
+
+func CreateEndpoint(s string) (Endpoint, error) {
+ var end NativeEndpoint
+ addr, err := parseEndpoint(s)
+ if err != nil {
+ return nil, err
+ }
+
+ ipv4 := addr.IP.To4()
+ if ipv4 != nil {
+ dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
+ dst.Family = unix.AF_INET
+ dst.Port = htons(uint16(addr.Port))
+ dst.Zero = [8]byte{}
+ copy(dst.Addr[:], ipv4)
+ end.ClearSrc()
+ return &end, nil
+ }
+
+ ipv6 := addr.IP.To16()
+ if ipv6 != nil {
+ zone, err := zoneToUint32(addr.Zone)
+ if err != nil {
+ return nil, err
+ }
+ dst := &end.dst
+ dst.Family = unix.AF_INET6
+ dst.Port = htons(uint16(addr.Port))
+ dst.Flowinfo = 0
+ dst.Scope_id = zone
+ copy(dst.Addr[:], ipv6[:])
+ end.ClearSrc()
+ return &end, nil
+ }
+
+ return nil, errors.New("Failed to recognize IP address format")
+}
+
+func CreateBind(port uint16) (Bind, uint16, error) {
+ var err error
+ var bind NativeBind
+
+ bind.sock6, port, err = create6(port)
+ if err != nil {
+ return nil, port, err
+ }
+
+ bind.sock4, port, err = create4(port)
+ if err != nil {
+ unix.Close(bind.sock6)
+ }
+ return bind, port, err
+}
+
+func (bind NativeBind) SetMark(value uint32) error {
+ err := unix.SetsockoptInt(
+ bind.sock6,
+ unix.SOL_SOCKET,
+ unix.SO_MARK,
+ int(value),
+ )
+
+ if err != nil {
+ return err
+ }
+
+ return unix.SetsockoptInt(
+ bind.sock4,
+ unix.SOL_SOCKET,
+ unix.SO_MARK,
+ int(value),
+ )
+}
+
+func closeUnblock(fd int) error {
+ // shutdown to unblock readers
+ unix.Shutdown(fd, unix.SHUT_RD)
+ return unix.Close(fd)
+}
+
+func (bind NativeBind) Close() error {
+ err1 := closeUnblock(bind.sock6)
+ err2 := closeUnblock(bind.sock4)
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+ var end NativeEndpoint
+ n, err := receive6(
+ bind.sock6,
+ buff,
+ &end,
+ )
+ return n, &end, err
+}
+
+func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+ var end NativeEndpoint
+ n, err := receive4(
+ bind.sock4,
+ buff,
+ &end,
+ )
+ return n, &end, err
+}
+
+func (bind NativeBind) Send(buff []byte, end Endpoint) error {
+ nend := end.(*NativeEndpoint)
+ switch nend.dst.Family {
+ case unix.AF_INET6:
+ return send6(bind.sock6, nend, buff)
+ case unix.AF_INET:
+ return send4(bind.sock4, nend, buff)
+ default:
+ return errors.New("Unknown address family of destination")
+ }
+}
+
+func sockaddrToString(addr unix.RawSockaddrInet6) string {
+ var udpAddr net.UDPAddr
+
+ switch addr.Family {
+ case unix.AF_INET6:
+ udpAddr.Port = int(ntohs(addr.Port))
+ udpAddr.IP = addr.Addr[:]
+ return udpAddr.String()
+
+ case unix.AF_INET:
+ ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
+ udpAddr.Port = int(ntohs(ptr.Port))
+ udpAddr.IP = net.IPv4(
+ ptr.Addr[0],
+ ptr.Addr[1],
+ ptr.Addr[2],
+ ptr.Addr[3],
+ )
+ return udpAddr.String()
+
+ default:
+ return "<unknown address family>"
+ }
+}
+
+func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
+ switch addr.Family {
+ case unix.AF_INET6:
+ return addr.Addr[:]
+ case unix.AF_INET:
+ ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
+ return net.IPv4(
+ ptr.Addr[0],
+ ptr.Addr[1],
+ ptr.Addr[2],
+ ptr.Addr[3],
+ )
+ default:
+ return nil
+ }
+}
+
+func (end *NativeEndpoint) SrcIP() net.IP {
+ return rawAddrToIP(end.src)
+}
+
+func (end *NativeEndpoint) DstIP() net.IP {
+ return rawAddrToIP(end.dst)
+}
+
+func (end *NativeEndpoint) DstToBytes() []byte {
+ ptr := unsafe.Pointer(&end.src)
+ arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
+ return arr[:]
+}
+
+func (end *NativeEndpoint) SrcToString() string {
+ return sockaddrToString(end.src)
+}
+
+func (end *NativeEndpoint) DstToString() string {
+ return sockaddrToString(end.dst)
+}
+
+func (end *NativeEndpoint) ClearDst() {
+ end.dst = unix.RawSockaddrInet6{}
+}
+
+func (end *NativeEndpoint) ClearSrc() {
+ end.src = unix.RawSockaddrInet6{}
}
func zoneToUint32(zone string) (uint32, error) {
@@ -42,51 +253,116 @@ func zoneToUint32(zone string) (uint32, error) {
return uint32(n), err
}
-func (end *Endpoint) ClearSrc() {
- end.srcIf4 = 0
- end.srcIPv4 = unix.RawSockaddrInet4{}
- end.srcIPv6 = unix.RawSockaddrInet6{}
-}
+func create4(port uint16) (int, uint16, error) {
+
+ // create socket
+
+ fd, err := unix.Socket(
+ unix.AF_INET,
+ unix.SOCK_DGRAM,
+ 0,
+ )
-func (end *Endpoint) Set(s string) error {
- addr, err := parseEndpoint(s)
if err != nil {
- return err
+ return -1, 0, err
}
- ipv6 := addr.IP.To16()
- if ipv6 != nil {
- zone, err := zoneToUint32(addr.Zone)
- if err != nil {
+ addr := unix.SockaddrInet4{
+ Port: int(port),
+ }
+
+ // set sockopts and bind
+
+ if err := func() error {
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.SOL_SOCKET,
+ unix.SO_REUSEADDR,
+ 1,
+ ); err != nil {
return err
}
- ptr := (*unix.RawSockaddrInet6)(unsafe.Pointer(&end.dst))
- ptr.Family = unix.AF_INET6
- ptr.Port = uint16(addr.Port)
- ptr.Flowinfo = 0
- ptr.Scope_id = zone
- copy(ptr.Addr[:], ipv6[:])
- end.ClearSrc()
- return nil
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IP,
+ unix.IP_PKTINFO,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ return unix.Bind(fd, &addr)
+ }(); err != nil {
+ unix.Close(fd)
}
- ipv4 := addr.IP.To4()
- if ipv4 != nil {
- ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
- ptr.Family = unix.AF_INET
- ptr.Port = uint16(addr.Port)
- ptr.Zero = [8]byte{}
- copy(ptr.Addr[:], ipv4)
- end.ClearSrc()
- return nil
+ return fd, uint16(addr.Port), err
+}
+
+func create6(port uint16) (int, uint16, error) {
+
+ // create socket
+
+ fd, err := unix.Socket(
+ unix.AF_INET6,
+ unix.SOCK_DGRAM,
+ 0,
+ )
+
+ if err != nil {
+ return -1, 0, err
+ }
+
+ // set sockopts and bind
+
+ addr := unix.SockaddrInet6{
+ Port: int(port),
+ }
+
+ if err := func() error {
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.SOL_SOCKET,
+ unix.SO_REUSEADDR,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IPV6,
+ unix.IPV6_RECVPKTINFO,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IPV6,
+ unix.IPV6_V6ONLY,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ return unix.Bind(fd, &addr)
+
+ }(); err != nil {
+ unix.Close(fd)
}
- return errors.New("Failed to recognize IP address format")
+ return fd, uint16(addr.Port), err
}
-func send6(sock uintptr, end *Endpoint, buff []byte) error {
- var iovec unix.Iovec
+func send6(sock int, end *NativeEndpoint, buff []byte) error {
+
+ // construct message header
+ var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
@@ -97,11 +373,11 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
- Len: unix.SizeofInet6Pktinfo,
+ Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet6Pktinfo{
- Addr: end.srcIPv6.Addr,
- Ifindex: end.srcIPv6.Scope_id,
+ Addr: end.src.Addr,
+ Ifindex: end.src.Scope_id,
},
}
@@ -119,22 +395,41 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
_, _, errno := unix.Syscall(
unix.SYS_SENDMSG,
- sock,
+ uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
+
+ if errno == 0 {
+ return nil
+ }
+
+ // clear src and retry
+
if errno == unix.EINVAL {
end.ClearSrc()
+ cmsg.pktinfo = unix.Inet6Pktinfo{}
+ _, _, errno = unix.Syscall(
+ unix.SYS_SENDMSG,
+ uintptr(sock),
+ uintptr(unsafe.Pointer(&msghdr)),
+ 0,
+ )
}
+
return errno
}
-func send4(sock uintptr, end *Endpoint, buff []byte) error {
- var iovec unix.Iovec
+func send4(sock int, end *NativeEndpoint, buff []byte) error {
+ // construct message header
+
+ var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
+ src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
+
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
@@ -142,11 +437,11 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
- Len: unix.SizeofInet6Pktinfo,
+ Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet4Pktinfo{
- Spec_dst: end.srcIPv4.Addr,
- Ifindex: end.srcIf4,
+ Spec_dst: src4.src.Addr,
+ Ifindex: src4.Ifindex,
},
}
@@ -156,51 +451,92 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
Name: (*byte)(unsafe.Pointer(&end.dst)),
Namelen: unix.SizeofSockaddrInet4,
Control: (*byte)(unsafe.Pointer(&cmsg)),
+ Flags: 0,
}
-
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
// sendmsg(sock, &msghdr, 0)
_, _, errno := unix.Syscall(
unix.SYS_SENDMSG,
- sock,
+ uintptr(sock),
uintptr(unsafe.Pointer(&msghdr)),
0,
)
+
+ // clear source and try again
+
if errno == unix.EINVAL {
end.ClearSrc()
+ cmsg.pktinfo = unix.Inet4Pktinfo{}
+ _, _, errno = unix.Syscall(
+ unix.SYS_SENDMSG,
+ uintptr(sock),
+ uintptr(unsafe.Pointer(&msghdr)),
+ 0,
+ )
+ }
+
+ // errno = 0 is still an error instance
+
+ if errno == 0 {
+ return nil
}
+
return errno
}
-func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
+func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
- // extract underlying file descriptor
+ // contruct message header
- file, err := c.File()
- if err != nil {
- return err
+ var iovec unix.Iovec
+ iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
+ iovec.SetLen(len(buff))
+
+ var cmsg struct {
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet4Pktinfo
+ }
+
+ var msghdr unix.Msghdr
+ msghdr.Iov = &iovec
+ msghdr.Iovlen = 1
+ msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
+ msghdr.Namelen = unix.SizeofSockaddrInet4
+ msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
+ msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
+
+ // recvmsg(sock, &mskhdr, 0)
+
+ size, _, errno := unix.Syscall(
+ unix.SYS_RECVMSG,
+ uintptr(sock),
+ uintptr(unsafe.Pointer(&msghdr)),
+ 0,
+ )
+
+ if errno != 0 {
+ return 0, errno
}
- sock := file.Fd()
- // send depending on address family of dst
+ // update source cache
- family := *((*uint16)(unsafe.Pointer(&end.dst)))
- if family == unix.AF_INET {
- return send4(sock, end, buff)
- } else if family == unix.AF_INET6 {
- return send6(sock, end, buff)
+ if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
+ cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
+ cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
+ src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
+ src4.src.Family = unix.AF_INET
+ src4.src.Addr = cmsg.pktinfo.Spec_dst
+ src4.Ifindex = cmsg.pktinfo.Ifindex
}
- return errors.New("Unknown address family of source")
+
+ return int(size), nil
}
-func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) {
+func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
- file, err := c.File()
- if err != nil {
- return err, nil, nil
- }
+ // contruct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
@@ -208,60 +544,39 @@ func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAdd
var cmsg struct {
cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo // big enough
+ pktinfo unix.Inet6Pktinfo
}
var msg unix.Msghdr
msg.Iov = &iovec
msg.Iovlen = 1
msg.Name = (*byte)(unsafe.Pointer(&end.dst))
- msg.Namelen = uint32(unix.SizeofSockaddrAny)
+ msg.Namelen = uint32(unix.SizeofSockaddrInet6)
msg.Control = (*byte)(unsafe.Pointer(&cmsg))
msg.SetControllen(int(unsafe.Sizeof(cmsg)))
- _, _, errno := unix.Syscall(
+ // recvmsg(sock, &mskhdr, 0)
+
+ size, _, errno := unix.Syscall(
unix.SYS_RECVMSG,
- file.Fd(),
+ uintptr(sock),
uintptr(unsafe.Pointer(&msg)),
0,
)
if errno != 0 {
- return errno, nil, nil
+ return 0, errno
}
+ // update source cache
+
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
-
- }
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
- cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
-
- info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo))
- println(info)
-
- }
-
- return nil, nil, nil
-}
-
-func setMark(conn *net.UDPConn, value uint32) error {
- if conn == nil {
- return nil
+ end.src.Family = unix.AF_INET6
+ end.src.Addr = cmsg.pktinfo.Addr
+ end.src.Scope_id = cmsg.pktinfo.Ifindex
}
- file, err := conn.File()
- if err != nil {
- return err
- }
-
- return unix.SetsockoptInt(
- int(file.Fd()),
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
+ return int(size), nil
}
diff --git a/src/cookie.go b/src/cookie.go
index a81819b..a13ad49 100644
--- a/src/cookie.go
+++ b/src/cookie.go
@@ -5,10 +5,8 @@ import (
"crypto/rand"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
- "net"
"sync"
"time"
- "unsafe"
)
type CookieChecker struct {
@@ -76,7 +74,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
return hmac.Equal(mac1[:], msg[smac1:smac2])
}
-func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
+func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
st.mutex.RLock()
defer st.mutex.RUnlock()
@@ -89,8 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
var cookie [blake2s.Size128]byte
func() {
mac, _ := blake2s.New128(st.mac2.secret[:])
- mac.Write(src.IP)
- mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
+ mac.Write(src)
mac.Sum(cookie[:0])
}()
@@ -111,7 +108,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src *net.UDPAddr) bool {
func (st *CookieChecker) CreateReply(
msg []byte,
recv uint32,
- src *net.UDPAddr,
+ src []byte,
) (*MessageCookieReply, error) {
st.mutex.RLock()
@@ -136,8 +133,7 @@ func (st *CookieChecker) CreateReply(
var cookie [blake2s.Size128]byte
func() {
mac, _ := blake2s.New128(st.mac2.secret[:])
- mac.Write(src.IP)
- mac.Write((*[unsafe.Sizeof(src.Port)]byte)(unsafe.Pointer(&src.Port))[:])
+ mac.Write(src)
mac.Sum(cookie[:0])
}()
diff --git a/src/cookie_test.go b/src/cookie_test.go
index 193a76e..d745fe7 100644
--- a/src/cookie_test.go
+++ b/src/cookie_test.go
@@ -1,7 +1,6 @@
package main
import (
- "net"
"testing"
)
@@ -25,7 +24,7 @@ func TestCookieMAC1(t *testing.T) {
// check mac1
- src, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4000")
+ src := []byte{192, 168, 13, 37, 10, 10, 10}
checkMAC1 := func(msg []byte) {
generator.AddMacs(msg)
@@ -128,12 +127,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20
- srcBad1, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4001")
+ srcBad1 := []byte{192, 168, 13, 37, 40, 01}
if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed")
}
- srcBad2, _ := net.ResolveUDPAddr("udp", "192.168.13.38:4000")
+ srcBad2 := []byte{192, 168, 13, 38, 40, 01}
if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed")
}
diff --git a/src/daemon_linux.go b/src/daemon_linux.go
index 730f89e..e1aaede 100644
--- a/src/daemon_linux.go
+++ b/src/daemon_linux.go
@@ -2,29 +2,25 @@ package main
import (
"os"
+ "os/exec"
)
/* Daemonizes the process on linux
*
* This is done by spawning and releasing a copy with the --foreground flag
- *
- * TODO: Use env variable to spawn in background
*/
+func Daemonize(attr *os.ProcAttr) error {
+ // I would like to use os.Executable,
+ // however this means dropping support for Go <1.8
+ path, err := exec.LookPath(os.Args[0])
+ if err != nil {
+ return err
+ }
-func Daemonize() error {
argv := []string{os.Args[0], "--foreground"}
argv = append(argv, os.Args[1:]...)
- attr := &os.ProcAttr{
- Dir: ".",
- Env: os.Environ(),
- Files: []*os.File{
- os.Stdin,
- nil,
- nil,
- },
- }
process, err := os.StartProcess(
- argv[0],
+ path,
argv,
attr,
)
diff --git a/src/device.go b/src/device.go
index 8567a36..76235bd 100644
--- a/src/device.go
+++ b/src/device.go
@@ -1,7 +1,6 @@
package main
import (
- "net"
"runtime"
"sync"
"sync/atomic"
@@ -9,8 +8,9 @@ import (
)
type Device struct {
- log *Logger // collection of loggers for levels
- idCounter uint // for assigning debug ids to peers
+ closed AtomicBool // device is closed? (acting as guard)
+ log *Logger // collection of loggers for levels
+ idCounter uint // for assigning debug ids to peers
fwMark uint32
tun struct {
device TUNDevice
@@ -22,9 +22,9 @@ type Device struct {
}
net struct {
mutex sync.RWMutex
- addr *net.UDPAddr // UDP source address
- conn *net.UDPConn // UDP "connection"
- fwmark uint32
+ bind Bind // bind interface
+ port uint16 // listening port
+ fwmark uint32 // mark value (0 = disabled)
}
mutex sync.RWMutex
privateKey NoisePrivateKey
@@ -37,8 +37,7 @@ type Device struct {
handshake chan QueueHandshakeElement
}
signal struct {
- stop chan struct{} // halts all go routines
- newUDPConn chan struct{} // a net.conn was set (consumed by the receiver routine)
+ stop chan struct{}
}
underLoadUntil atomic.Value
ratelimiter Ratelimiter
@@ -128,21 +127,23 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
device.pool.messageBuffers.Put(msg)
}
-func NewDevice(tun TUNDevice, logLevel int) *Device {
+func NewDevice(tun TUNDevice, logger *Logger) *Device {
device := new(Device)
device.mutex.Lock()
defer device.mutex.Unlock()
- device.log = NewLogger(logLevel, "("+tun.Name()+") ")
+ device.log = logger
device.peers = make(map[NoisePublicKey]*Peer)
device.tun.device = tun
+
device.indices.Init()
device.ratelimiter.Init()
+
device.routingTable.Reset()
device.underLoadUntil.Store(time.Time{})
- // setup pools
+ // setup buffer pool
device.pool.messageBuffers = sync.Pool{
New: func() interface{} {
@@ -159,7 +160,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
// prepare signals
device.signal.stop = make(chan struct{})
- device.signal.newUDPConn = make(chan struct{}, 1)
+
+ // prepare net
+
+ device.net.port = 0
+ device.net.bind = nil
// start workers
@@ -168,12 +173,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
go device.RoutineDecryption()
go device.RoutineHandshake()
}
-
+ go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
- go device.RoutineReadFromTUN()
- go device.RoutineReceiveIncomming()
-
return device
}
@@ -202,9 +204,13 @@ func (device *Device) RemoveAllPeers() {
}
func (device *Device) Close() {
+ if device.closed.Swap(true) {
+ return
+ }
+ device.log.Info.Println("Closing device")
device.RemoveAllPeers()
close(device.signal.stop)
- closeUDPConn(device)
+ CloseUDPListener(device)
device.tun.device.Close()
}
diff --git a/src/helper_test.go b/src/helper_test.go
index fc171e8..8548121 100644
--- a/src/helper_test.go
+++ b/src/helper_test.go
@@ -2,6 +2,7 @@ package main
import (
"bytes"
+ "os"
"testing"
)
@@ -15,6 +16,10 @@ type DummyTUN struct {
events chan TUNEvent
}
+func (tun *DummyTUN) File() *os.File {
+ return nil
+}
+
func (tun *DummyTUN) Name() string {
return tun.name
}
@@ -67,7 +72,8 @@ func randDevice(t *testing.T) *Device {
t.Fatal(err)
}
tun, _ := CreateDummyTUN("dummy")
- device := NewDevice(tun, LogLevelError)
+ logger := NewLogger(LogLevelError, "")
+ device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}
diff --git a/src/main.go b/src/main.go
index 196a4c6..7d86716 100644
--- a/src/main.go
+++ b/src/main.go
@@ -2,10 +2,15 @@ package main
import (
"fmt"
- "log"
"os"
"os/signal"
"runtime"
+ "strconv"
+)
+
+const (
+ ENV_WG_TUN_FD = "WG_TUN_FD"
+ ENV_WG_UAPI_FD = "WG_UAPI_FD"
)
func printUsage() {
@@ -43,28 +48,6 @@ func main() {
interfaceName = os.Args[1]
}
- // daemonize the process
-
- if !foreground {
- err := Daemonize()
- if err != nil {
- log.Println("Failed to daemonize:", err)
- }
- return
- }
-
- // increase number of go workers (for Go <1.5)
-
- runtime.GOMAXPROCS(runtime.NumCPU())
-
- // open TUN device
-
- tun, err := CreateTUN(interfaceName)
- if err != nil {
- log.Println("Failed to create tun device:", err)
- return
- }
-
// get log level (default: info)
logLevel := func() int {
@@ -79,25 +62,103 @@ func main() {
return LogLevelInfo
}()
- // create wireguard device
+ logger := NewLogger(
+ logLevel,
+ fmt.Sprintf("(%s) ", interfaceName),
+ )
+
+ logger.Debug.Println("Debug log enabled")
+
+ // open TUN device (or use supplied fd)
+
+ tun, err := func() (TUNDevice, error) {
+ tunFdStr := os.Getenv(ENV_WG_TUN_FD)
+ if tunFdStr == "" {
+ return CreateTUN(interfaceName)
+ }
+
+ // construct tun device from supplied fd
+
+ fd, err := strconv.ParseUint(tunFdStr, 10, 32)
+ if err != nil {
+ return nil, err
+ }
+
+ file := os.NewFile(uintptr(fd), "")
+ return CreateTUNFromFile(interfaceName, file)
+ }()
+
+ if err != nil {
+ logger.Error.Println("Failed to create TUN device:", err)
+ os.Exit(ExitSetupFailed)
+ }
- device := NewDevice(tun, logLevel)
+ // open UAPI file (or use supplied fd)
- logInfo := device.log.Info
- logError := device.log.Error
- logInfo.Println("Starting device")
+ fileUAPI, err := func() (*os.File, error) {
+ uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
+ if uapiFdStr == "" {
+ return UAPIOpen(interfaceName)
+ }
+
+ // use supplied fd
- // start configuration lister
+ fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
+ if err != nil {
+ return nil, err
+ }
+
+ return os.NewFile(uintptr(fd), ""), nil
+ }()
- uapi, err := NewUAPIListener(interfaceName)
if err != nil {
- logError.Fatal("UAPI listen error:", err)
+ logger.Error.Println("UAPI listen error:", err)
+ os.Exit(ExitSetupFailed)
+ return
}
+ // daemonize the process
+
+ if !foreground {
+ env := os.Environ()
+ env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
+ env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
+ attr := &os.ProcAttr{
+ Files: []*os.File{
+ nil, // stdin
+ nil, // stdout
+ nil, // stderr
+ tun.File(),
+ fileUAPI,
+ },
+ Dir: ".",
+ Env: env,
+ }
+ err = Daemonize(attr)
+ if err != nil {
+ logger.Error.Println("Failed to daemonize:", err)
+ os.Exit(ExitSetupFailed)
+ }
+ return
+ }
+
+ // increase number of go workers (for Go <1.5)
+
+ runtime.GOMAXPROCS(runtime.NumCPU())
+
+ // create wireguard device
+
+ device := NewDevice(tun, logger)
+
+ logger.Info.Println("Device started")
+
+ // start uapi listener
errs := make(chan error)
term := make(chan os.Signal)
wait := device.WaitChannel()
+ uapi, err := UAPIListen(interfaceName, fileUAPI)
+
go func() {
for {
conn, err := uapi.Accept()
@@ -109,7 +170,7 @@ func main() {
}
}()
- logInfo.Println("UAPI listener started")
+ logger.Info.Println("UAPI listener started")
// wait for program to terminate
@@ -122,9 +183,10 @@ func main() {
case <-errs:
}
- // clean up UAPI bind
+ // clean up
uapi.Close()
+ device.Close()
- logInfo.Println("Closing")
+ logger.Info.Println("Shutting down")
}
diff --git a/src/misc.go b/src/misc.go
index bbe0d68..b43e97e 100644
--- a/src/misc.go
+++ b/src/misc.go
@@ -21,6 +21,14 @@ func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.flag) == AtomicTrue
}
+func (a *AtomicBool) Swap(val bool) bool {
+ flag := AtomicFalse
+ if val {
+ flag = AtomicTrue
+ }
+ return atomic.SwapInt32(&a.flag, flag) == AtomicTrue
+}
+
func (a *AtomicBool) Set(val bool) {
flag := AtomicFalse
if val {
diff --git a/src/noise_test.go b/src/noise_test.go
index 48408f9..0d7f0e9 100644
--- a/src/noise_test.go
+++ b/src/noise_test.go
@@ -117,8 +117,8 @@ func TestNoiseHandshake(t *testing.T) {
var err error
var out []byte
var nonce [12]byte
- out = key1.send.aead.Seal(out, nonce[:], testMsg, nil)
- out, err = key2.receive.aead.Open(out[:0], nonce[:], out, nil)
+ out = key1.send.Seal(out, nonce[:], testMsg, nil)
+ out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
assertNil(t, err)
assertEqual(t, out, testMsg)
}()
@@ -128,8 +128,8 @@ func TestNoiseHandshake(t *testing.T) {
var err error
var out []byte
var nonce [12]byte
- out = key2.send.aead.Seal(out, nonce[:], testMsg, nil)
- out, err = key1.receive.aead.Open(out[:0], nonce[:], out, nil)
+ out = key2.send.Seal(out, nonce[:], testMsg, nil)
+ out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
assertNil(t, err)
assertEqual(t, out, testMsg)
}()
diff --git a/src/peer.go b/src/peer.go
index 6fea829..f3eb6c2 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -4,7 +4,6 @@ import (
"encoding/base64"
"errors"
"fmt"
- "net"
"sync"
"time"
)
@@ -16,7 +15,7 @@ type Peer struct {
keyPairs KeyPairs
handshake Handshake
device *Device
- endpoint *net.UDPAddr
+ endpoint Endpoint
stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
@@ -106,6 +105,10 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
handshake.mutex.Unlock()
+ // reset endpoint
+
+ peer.endpoint = nil
+
// prepare queuing
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
@@ -130,11 +133,31 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}
+func (peer *Peer) SendBuffer(buffer []byte) error {
+ peer.device.net.mutex.RLock()
+ defer peer.device.net.mutex.RUnlock()
+ peer.mutex.RLock()
+ defer peer.mutex.RUnlock()
+ if peer.endpoint == nil {
+ return errors.New("No known endpoint for peer")
+ }
+ return peer.device.net.bind.Send(buffer, peer.endpoint)
+}
+
+/* Returns a short string identification for logging
+ */
func (peer *Peer) String() string {
+ if peer.endpoint == nil {
+ return fmt.Sprintf(
+ "peer(%d unknown %s)",
+ peer.id,
+ base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
+ )
+ }
return fmt.Sprintf(
"peer(%d %s %s)",
peer.id,
- peer.endpoint.String(),
+ peer.endpoint.DstToString(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
}
diff --git a/src/receive.go b/src/receive.go
index 52c2718..27fdb8a 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -13,19 +13,20 @@ import (
)
type QueueHandshakeElement struct {
- msgType uint32
- packet []byte
- buffer *[MaxMessageSize]byte
- source *net.UDPAddr
+ msgType uint32
+ packet []byte
+ endpoint Endpoint
+ buffer *[MaxMessageSize]byte
}
type QueueInboundElement struct {
- dropped int32
- mutex sync.Mutex
- buffer *[MaxMessageSize]byte
- packet []byte
- counter uint64
- keyPair *KeyPair
+ dropped int32
+ mutex sync.Mutex
+ buffer *[MaxMessageSize]byte
+ packet []byte
+ counter uint64
+ keyPair *KeyPair
+ endpoint Endpoint
}
func (elem *QueueInboundElement) Drop() {
@@ -92,130 +93,122 @@ func (device *Device) addToHandshakeQueue(
}
}
-func (device *Device) RoutineReceiveIncomming() {
+func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) {
logDebug := device.log.Debug
- logDebug.Println("Routine, receive incomming, started")
+ logDebug.Println("Routine, receive incomming, IP version:", IP)
for {
- // wait for new conn
+ // receive datagrams until conn is closed
- logDebug.Println("Waiting for udp socket")
+ buffer := device.GetMessageBuffer()
- select {
- case <-device.signal.stop:
- return
+ var (
+ err error
+ size int
+ endpoint Endpoint
+ )
- case <-device.signal.newUDPConn:
+ for {
- // fetch connection
+ // read next datagram
- device.net.mutex.RLock()
- conn := device.net.conn
- device.net.mutex.RUnlock()
- if conn == nil {
+ switch IP {
+ case ipv4.Version:
+ size, endpoint, err = bind.ReceiveIPv4(buffer[:])
+ case ipv6.Version:
+ size, endpoint, err = bind.ReceiveIPv6(buffer[:])
+ default:
+ return
+ }
+
+ if err != nil {
+ break
+ }
+
+ if size < MinMessageSize {
continue
}
- logDebug.Println("Listening for inbound packets")
+ // check size of packet
- // receive datagrams until conn is closed
+ packet := buffer[:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
- buffer := device.GetMessageBuffer()
+ var okay bool
- for {
+ switch msgType {
- // read next datagram
+ // check if transport
- size, raddr, err := conn.ReadFromUDP(buffer[:])
+ case MessageTransportType:
- if err != nil {
- break
- }
+ // check size
- if size < MinMessageSize {
+ if len(packet) < MessageTransportType {
continue
}
- // check size of packet
-
- packet := buffer[:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
-
- var okay bool
+ // lookup key pair
- switch msgType {
-
- // check if transport
-
- case MessageTransportType:
-
- // check size
-
- if len(packet) < MessageTransportType {
- continue
- }
-
- // lookup key pair
-
- receiver := binary.LittleEndian.Uint32(
- packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
- )
- value := device.indices.Lookup(receiver)
- keyPair := value.keyPair
- if keyPair == nil {
- continue
- }
+ receiver := binary.LittleEndian.Uint32(
+ packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+ )
+ value := device.indices.Lookup(receiver)
+ keyPair := value.keyPair
+ if keyPair == nil {
+ continue
+ }
- // check key-pair expiry
+ // check key-pair expiry
- if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
- continue
- }
+ if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
+ continue
+ }
- // create work element
+ // create work element
- peer := value.peer
- elem := &QueueInboundElement{
- packet: packet,
- buffer: buffer,
- keyPair: keyPair,
- dropped: AtomicFalse,
- }
- elem.mutex.Lock()
+ peer := value.peer
+ elem := &QueueInboundElement{
+ packet: packet,
+ buffer: buffer,
+ keyPair: keyPair,
+ dropped: AtomicFalse,
+ endpoint: endpoint,
+ }
+ elem.mutex.Lock()
- // add to decryption queues
+ // add to decryption queues
- device.addToDecryptionQueue(device.queue.decryption, elem)
- device.addToInboundQueue(peer.queue.inbound, elem)
- buffer = device.GetMessageBuffer()
- continue
+ device.addToDecryptionQueue(device.queue.decryption, elem)
+ device.addToInboundQueue(peer.queue.inbound, elem)
+ buffer = device.GetMessageBuffer()
+ continue
- // otherwise it is a handshake related packet
+ // otherwise it is a fixed size & handshake related packet
- case MessageInitiationType:
- okay = len(packet) == MessageInitiationSize
+ case MessageInitiationType:
+ okay = len(packet) == MessageInitiationSize
- case MessageResponseType:
- okay = len(packet) == MessageResponseSize
+ case MessageResponseType:
+ okay = len(packet) == MessageResponseSize
- case MessageCookieReplyType:
- okay = len(packet) == MessageCookieReplySize
- }
+ case MessageCookieReplyType:
+ okay = len(packet) == MessageCookieReplySize
+ }
- if okay {
- device.addToHandshakeQueue(
- device.queue.handshake,
- QueueHandshakeElement{
- msgType: msgType,
- buffer: buffer,
- packet: packet,
- source: raddr,
- },
- )
- buffer = device.GetMessageBuffer()
- }
+ if okay {
+ device.addToHandshakeQueue(
+ device.queue.handshake,
+ QueueHandshakeElement{
+ msgType: msgType,
+ buffer: buffer,
+ packet: packet,
+ endpoint: endpoint,
+ },
+ )
+ buffer = device.GetMessageBuffer()
}
}
}
@@ -293,8 +286,6 @@ func (device *Device) RoutineHandshake() {
// unmarshal packet
- logDebug.Println("Process cookie reply from:", elem.source.String())
-
var reply MessageCookieReply
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply)
@@ -321,15 +312,25 @@ func (device *Device) RoutineHandshake() {
return
}
+ // endpoints destination address is the source of the datagram
+
+ srcBytes := elem.endpoint.DstToBytes()
+
if device.IsUnderLoad() {
- if !device.mac.CheckMAC2(elem.packet, elem.source) {
+
+ // verify MAC2 field
+
+ if !device.mac.CheckMAC2(elem.packet, srcBytes) {
// construct cookie reply
- logDebug.Println("Sending cookie reply to:", elem.source.String())
+ logDebug.Println(
+ "Sending cookie reply to:",
+ elem.endpoint.DstToString(),
+ )
- sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
- reply, err := device.mac.CreateReply(elem.packet, sender, elem.source)
+ sender := binary.LittleEndian.Uint32(elem.packet[4:8])
+ reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
if err != nil {
logError.Println("Failed to create cookie reply:", err)
return
@@ -339,17 +340,16 @@ func (device *Device) RoutineHandshake() {
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, reply)
- _, err = device.net.conn.WriteToUDP(
- writer.Bytes(),
- elem.source,
- )
+ device.net.bind.Send(writer.Bytes(), elem.endpoint)
if err != nil {
logDebug.Println("Failed to send cookie reply:", err)
}
continue
}
- if !device.ratelimiter.Allow(elem.source.IP) {
+ // check ratelimiter
+
+ if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
continue
}
}
@@ -380,8 +380,7 @@ func (device *Device) RoutineHandshake() {
if peer == nil {
logInfo.Println(
"Recieved invalid initiation message from",
- elem.source.IP.String(),
- elem.source.Port,
+ elem.endpoint.DstToString(),
)
continue
}
@@ -392,10 +391,9 @@ func (device *Device) RoutineHandshake() {
peer.TimerAnyAuthenticatedPacketReceived()
// update endpoint
- // TODO: Discover destination address also, only update on change
peer.mutex.Lock()
- peer.endpoint = elem.source
+ peer.endpoint = elem.endpoint
peer.mutex.Unlock()
// create response
@@ -418,9 +416,11 @@ func (device *Device) RoutineHandshake() {
// send response
- _, err = peer.SendBuffer(packet)
+ err = peer.SendBuffer(packet)
if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal()
+ } else {
+ logError.Println("Failed to send response to:", peer.String(), err)
}
case MessageResponseType:
@@ -441,12 +441,17 @@ func (device *Device) RoutineHandshake() {
if peer == nil {
logInfo.Println(
"Recieved invalid response message from",
- elem.source.IP.String(),
- elem.source.Port,
+ elem.endpoint.DstToString(),
)
continue
}
+ // update endpoint
+
+ peer.mutex.Lock()
+ peer.endpoint = elem.endpoint
+ peer.mutex.Unlock()
+
logDebug.Println("Received handshake initation from", peer)
peer.TimerEphemeralKeyCreated()
@@ -515,6 +520,12 @@ func (peer *Peer) RoutineSequentialReceiver() {
}
kp.mutex.Unlock()
+ // update endpoint
+
+ peer.mutex.Lock()
+ peer.endpoint = elem.endpoint
+ peer.mutex.Unlock()
+
// check for keep-alive
if len(elem.packet) == 0 {
@@ -546,7 +557,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer {
- logInfo.Println("Packet with unallowed source IP from", peer.String())
+ logInfo.Println(
+ "IPv4 packet with unallowed source address from",
+ peer.String(),
+ )
continue
}
@@ -571,7 +585,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer {
- logInfo.Println("Packet with unallowed source IP from", peer.String())
+ logInfo.Println(
+ "IPv6 packet with unallowed source address from",
+ peer.String(),
+ )
continue
}
@@ -580,7 +597,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue
}
- // write to tun
+ // write to tun device
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.device.Write(elem.packet)
diff --git a/src/send.go b/src/send.go
index 5c88ead..52872f6 100644
--- a/src/send.go
+++ b/src/send.go
@@ -2,7 +2,6 @@ package main
import (
"encoding/binary"
- "errors"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
@@ -105,26 +104,6 @@ func addToEncryptionQueue(
}
}
-func (peer *Peer) SendBuffer(buffer []byte) (int, error) {
- peer.device.net.mutex.RLock()
- defer peer.device.net.mutex.RUnlock()
-
- peer.mutex.RLock()
- defer peer.mutex.RUnlock()
-
- endpoint := peer.endpoint
- if endpoint == nil {
- return 0, errors.New("No known endpoint for peer")
- }
-
- conn := peer.device.net.conn
- if conn == nil {
- return 0, errors.New("No UDP socket for device")
- }
-
- return conn.WriteToUDP(buffer, endpoint)
-}
-
/* Reads packets from the TUN and inserts
* into nonce queue for peer
*
@@ -343,7 +322,7 @@ func (peer *Peer) RoutineSequentialSender() {
// send message and return buffer to pool
length := uint64(len(elem.packet))
- _, err := peer.SendBuffer(elem.packet)
+ err := peer.SendBuffer(elem.packet)
device.PutMessageBuffer(elem.buffer)
if err != nil {
logDebug.Println("Failed to send authenticated packet to peer", peer.String())
diff --git a/src/tests/netns.sh b/src/tests/netns.sh
index 043da3e..22abea8 100755
--- a/src/tests/netns.sh
+++ b/src/tests/netns.sh
@@ -20,6 +20,14 @@
# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1
# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further
# details on how this is accomplished.
+
+# This code is ported to the WireGuard-Go directly from the kernel project.
+#
+# Please ensure that you have installed the newest version of the WireGuard
+# tools from the WireGuard project and before running these tests as:
+#
+# ./netns.sh <path to wireguard-go>
+
set -e
exec 3>&1
@@ -27,8 +35,8 @@ export WG_HIDE_KEYS=never
netns0="wg-test-$$-0"
netns1="wg-test-$$-1"
netns2="wg-test-$$-2"
-program="../wireguard-go"
-export LOG_LEVEL="error"
+program=$1
+export LOG_LEVEL="info"
pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; }
pp() { pretty "" "$*"; "$@"; }
@@ -72,13 +80,11 @@ pp ip netns add $netns2
ip0 link set up dev lo
# ip0 link add dev wg1 type wireguard
-n0 $program -f wg1 &
-sleep 1
+n0 $program wg1
ip0 link set wg1 netns $netns1
# ip0 link add dev wg1 type wireguard
-n0 $program -f wg2 &
-sleep 1
+n0 $program wg2
ip0 link set wg2 netns $netns2
key1="$(pp wg genkey)"
@@ -185,14 +191,14 @@ ip0 -4 addr del 127.0.0.1/8 dev lo
ip0 -4 addr add 127.212.121.99/8 dev lo
n0 wg set wg1 listen-port 9999
n0 wg set wg1 peer "$pub2" endpoint 127.0.0.1:20000
-n1 ping6 -W 1 -c 1 fd00::20000
-[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]]
+n1 ping6 -W 1 -c 1 fd00::2
+[[ $(n2 wg show wg2 endpoints) == "$pub1 127.212.121.99:9999" ]]
# Test using IPv6 that roaming works
n1 wg set wg1 listen-port 9998
n1 wg set wg1 peer "$pub2" endpoint [::1]:20000
n1 ping -W 1 -c 1 192.168.241.2
-[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]]
+[[ $(n2 wg show wg2 endpoints) == "$pub1 [::1]:9998" ]]
# Test that crypto-RP filter works
n1 wg set wg1 peer "$pub2" allowed-ips 192.168.241.0/24
@@ -212,7 +218,7 @@ n2 ncat -u 192.168.241.1 1111 <<<"X"
! read -r -N 1 -t 1 out <&4
kill $nmap_pid
n0 wg set wg1 peer "$more_specific_key" remove
-[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]]
+[[ $(n1 wg show wg1 endpoints) == "$pub2 [::1]:9997" ]]
ip1 link del wg1
ip2 link del wg2
@@ -263,7 +269,7 @@ n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to
n0 wg set wg1 peer "$pub2" endpoint 10.0.0.100:20000 persistent-keepalive 1
n1 ping -W 1 -c 1 192.168.241.2
n2 ping -W 1 -c 1 192.168.241.1
-[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
+[[ $(n2 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
# Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`).
pp sleep 3
n2 ping -W 1 -c 1 192.168.241.1
@@ -289,7 +295,7 @@ ip2 link del wg2
# ip1 link add dev wg1 type wireguard
# ip2 link add dev wg1 type wireguard
n1 $program wg1
-n2 $program wg1
+n2 $program wg2
configure_peers
@@ -336,17 +342,83 @@ waitiface $netns1 veth1
waitiface $netns2 veth2
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.1:10000
n2 ping -W 1 -c 1 192.168.241.1
-[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
+[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.1:10000" ]]
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::1]:10000
n2 ping -W 1 -c 1 192.168.241.1
-[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]]
+[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::1]:10000" ]]
n0 wg set wg2 peer "$pub1" endpoint 10.0.0.2:10000
n2 ping -W 1 -c 1 192.168.241.1
-[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]]
+[[ $(n0 wg show wg2 endpoints) == "$pub1 10.0.0.2:10000" ]]
n0 wg set wg2 peer "$pub1" endpoint [fd00:aa::2]:10000
n2 ping -W 1 -c 1 192.168.241.1
-[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]]
+[[ $(n0 wg show wg2 endpoints) == "$pub1 [fd00:aa::2]:10000" ]]
ip1 link del veth1
ip1 link del wg1
ip2 link del wg2
+
+# Test that Netlink/IPC is working properly by doing things that usually cause split responses
+
+n0 $program wg0
+sleep 5
+config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" )
+for a in {1..255}; do
+ for b in {0..255}; do
+ config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" )
+ done
+done
+n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
+i=0
+for ip in $(n0 wg show wg0 allowed-ips); do
+ ((++i))
+done
+((i == 255*256*2+1))
+ip0 link del wg0
+
+n0 $program wg0
+config=( "[Interface]" "PrivateKey=$(wg genkey)" )
+for a in {1..40}; do
+ config+=( "[Peer]" "PublicKey=$(wg genkey)" )
+ for b in {1..52}; do
+ config+=( "AllowedIPs=$a.$b.0.0/16" )
+ done
+done
+n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
+i=0
+while read -r line; do
+ j=0
+ for ip in $line; do
+ ((++j))
+ done
+ ((j == 53))
+ ((++i))
+done < <(n0 wg show wg0 allowed-ips)
+((i == 40))
+ip0 link del wg0
+
+n0 $program wg0
+config=( )
+for i in {1..29}; do
+ config+=( "[Peer]" "PublicKey=$(wg genkey)" )
+done
+config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" )
+n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
+n0 wg showconf wg0 > /dev/null
+ip0 link del wg0
+
+! n0 wg show doesnotexist || false
+
+declare -A objects
+while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
+ [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
+ objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
+done < /dev/kmsg
+alldeleted=1
+for object in "${!objects[@]}"; do
+ if [[ ${objects["$object"]} != *createddestroyed ]]; then
+ echo "Error: $object: merely ${objects["$object"]}" >&3
+ alldeleted=0
+ fi
+done
+[[ $alldeleted -eq 1 ]]
+pretty "" "Objects that were created were also destroyed."
diff --git a/src/timers.go b/src/timers.go
index 99695ba..31165a3 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -279,34 +279,31 @@ func (peer *Peer) RoutineHandshakeInitiator() {
break AttemptHandshakes
}
- jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
-
- // marshal and send
+ // marshal handshake message
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.mac.AddMacs(packet)
- _, err = peer.SendBuffer(packet)
- if err != nil {
+ // send to endpoint
+
+ err = peer.SendBuffer(packet)
+ jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
+ timeout := time.NewTimer(RekeyTimeout + jitter)
+ if err == nil {
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ logDebug.Println(
+ "Handshake initiation attempt",
+ attempts, "sent to", peer.String(),
+ )
+ } else {
logError.Println(
"Failed to send handshake initiation message to",
peer.String(), ":", err,
)
- continue
}
- peer.TimerAnyAuthenticatedPacketTraversal()
-
- // set handshake timeout
-
- timeout := time.NewTimer(RekeyTimeout + jitter)
- logDebug.Println(
- "Handshake initiation attempt",
- attempts, "sent to", peer.String(),
- )
-
// wait for handshake or timeout
select {
diff --git a/src/tun.go b/src/tun.go
index 8e8c759..5bdac0e 100644
--- a/src/tun.go
+++ b/src/tun.go
@@ -1,6 +1,7 @@
package main
import (
+ "os"
"sync/atomic"
)
@@ -15,6 +16,7 @@ const (
)
type TUNDevice interface {
+ File() *os.File // returns the file descriptor of the device
Read([]byte) (int, error) // read a packet from the device (without any additional headers)
Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
MTU() (int, error) // returns the MTU of the device
@@ -47,7 +49,7 @@ func (device *Device) RoutineTUNEventReader() {
if !device.tun.isUp.Get() {
logInfo.Println("Interface set up")
device.tun.isUp.Set(true)
- updateUDPConn(device)
+ UpdateUDPListener(device)
}
}
@@ -55,7 +57,7 @@ func (device *Device) RoutineTUNEventReader() {
if device.tun.isUp.Get() {
logInfo.Println("Interface set down")
device.tun.isUp.Set(false)
- closeUDPConn(device)
+ CloseUDPListener(device)
}
}
}
diff --git a/src/tun_linux.go b/src/tun_linux.go
index accc6c6..a728a48 100644
--- a/src/tun_linux.go
+++ b/src/tun_linux.go
@@ -56,6 +56,10 @@ type NativeTun struct {
events chan TUNEvent // device related events
}
+func (tun *NativeTun) File() *os.File {
+ return tun.fd
+}
+
func (tun *NativeTun) RoutineNetlinkListener() {
sock := int(C.bind_rtmgrp())
if sock < 0 {
@@ -222,7 +226,7 @@ func (tun *NativeTun) MTU() (int, error) {
val := binary.LittleEndian.Uint32(ifr[16:20])
if val >= (1 << 31) {
- return int(val-(1<<31)) - (1 << 31), nil
+ return int(toInt32(val)), nil
}
return int(val), nil
}
@@ -248,6 +252,29 @@ func (tun *NativeTun) Close() error {
return nil
}
+func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) {
+ device := &NativeTun{
+ fd: fd,
+ name: name,
+ events: make(chan TUNEvent, 5),
+ errors: make(chan error, 5),
+ }
+
+ // start event listener
+
+ var err error
+ device.index, err = getIFIndex(device.name)
+ if err != nil {
+ return nil, err
+ }
+
+ go device.RoutineNetlinkListener()
+
+ // set default MTU
+
+ return device, device.setMTU(DefaultMTU)
+}
+
func CreateTUN(name string) (TUNDevice, error) {
// open clone device
diff --git a/src/uapi.go b/src/uapi.go
index 326216b..dc8be66 100644
--- a/src/uapi.go
+++ b/src/uapi.go
@@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send("private_key=" + device.privateKey.ToHex())
}
- if device.net.addr != nil {
- send(fmt.Sprintf("listen_port=%d", device.net.addr.Port))
+ if device.net.port != 0 {
+ send(fmt.Sprintf("listen_port=%d", device.net.port))
}
+
if device.net.fwmark != 0 {
send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
}
@@ -53,7 +54,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
if peer.endpoint != nil {
- send("endpoint=" + peer.endpoint.String())
+ send("endpoint=" + peer.endpoint.DstToString())
}
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
@@ -134,56 +135,38 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "listen_port":
port, err := strconv.ParseUint(value, 10, 16)
if err != nil {
- logError.Println("Failed to set listen_port:", err)
+ logError.Println("Failed to parse listen_port:", err)
return &IPCError{Code: ipcErrorInvalid}
}
-
- addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
- if err != nil {
+ device.net.port = uint16(port)
+ if err := UpdateUDPListener(device); err != nil {
logError.Println("Failed to set listen_port:", err)
- return &IPCError{Code: ipcErrorInvalid}
+ return &IPCError{Code: ipcErrorPortInUse}
}
- device.net.mutex.Lock()
- device.net.addr = addr
- device.net.mutex.Unlock()
+ case "fwmark":
- err = updateUDPConn(device)
- if err != nil {
- logError.Println("Failed to set listen_port:", err)
- return &IPCError{Code: ipcErrorPortInUse}
- }
+ // parse fwmark field
- // TODO: Clear source address of all peers
+ fwmark, err := func() (uint32, error) {
+ if value == "" {
+ return 0, nil
+ }
+ mark, err := strconv.ParseUint(value, 10, 32)
+ return uint32(mark), err
+ }()
- case "fwmark":
- fwmark, err := strconv.ParseUint(value, 10, 32)
if err != nil {
logError.Println("Invalid fwmark", err)
return &IPCError{Code: ipcErrorInvalid}
}
device.net.mutex.Lock()
- if fwmark > 0 || device.net.fwmark > 0 {
- device.net.fwmark = uint32(fwmark)
- err := setMark(
- device.net.conn,
- device.net.fwmark,
- )
- if err != nil {
- logError.Println("Failed to set fwmark:", err)
- device.net.mutex.Unlock()
- return &IPCError{Code: ipcErrorIO}
- }
-
- // TODO: Clear source address of all peers
- }
+ device.net.fwmark = uint32(fwmark)
device.net.mutex.Unlock()
case "public_key":
-
// switch to peer configuration
-
deviceConfig = false
case "replace_peers":
@@ -218,7 +201,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.mutex.RLock()
if device.publicKey.Equals(pubKey) {
- // create dummy instance
+ // create dummy instance (not added to device)
peer = &Peer{}
dummy = true
@@ -244,6 +227,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
case "remove":
+
+ // remove currently selected peer from device
+
if value != "true" {
logError.Println("Failed to set remove, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
@@ -256,6 +242,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
dummy = true
case "preshared_key":
+
+ // update PSK
+
peer.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value)
peer.mutex.Unlock()
@@ -265,15 +254,25 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
case "endpoint":
- addr, err := parseEndpoint(value)
+
+ // set endpoint destination
+
+ err := func() error {
+ peer.mutex.Lock()
+ defer peer.mutex.Unlock()
+ endpoint, err := CreateEndpoint(value)
+ if err != nil {
+ return err
+ }
+ peer.endpoint = endpoint
+ signalSend(peer.signal.handshakeReset)
+ return nil
+ }()
+
if err != nil {
logError.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalid}
}
- peer.mutex.Lock()
- peer.endpoint = addr
- peer.mutex.Unlock()
- signalSend(peer.signal.handshakeReset)
case "persistent_keepalive_interval":
diff --git a/src/uapi_linux.go b/src/uapi_linux.go
index cb9d858..f97a18a 100644
--- a/src/uapi_linux.go
+++ b/src/uapi_linux.go
@@ -10,12 +10,12 @@ import (
)
const (
- ipcErrorIO = -int64(unix.EIO)
- ipcErrorProtocol = -int64(unix.EPROTO)
- ipcErrorInvalid = -int64(unix.EINVAL)
- ipcErrorPortInUse = -int64(unix.EADDRINUSE)
- socketDirectory = "/var/run/wireguard"
- socketName = "%s.sock"
+ ipcErrorIO = -int64(unix.EIO)
+ ipcErrorProtocol = -int64(unix.EPROTO)
+ ipcErrorInvalid = -int64(unix.EINVAL)
+ ipcErrorPortInUse = -int64(unix.EADDRINUSE)
+ socketDirectory = "/var/run/wireguard"
+ socketName = "%s.sock"
)
type UAPIListener struct {
@@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr {
return nil
}
-func connectUnixSocket(path string) (net.Listener, error) {
+func UAPIListen(name string, file *os.File) (net.Listener, error) {
- // attempt inital connection
+ // wrap file in listener
- listener, err := net.Listen("unix", path)
- if err == nil {
- return listener, nil
- }
-
- // check if active
-
- _, err = net.Dial("unix", path)
- if err == nil {
- return nil, errors.New("Unix socket in use")
- }
-
- // attempt cleanup
-
- err = os.Remove(path)
- if err != nil {
- return nil, err
- }
-
- return net.Listen("unix", path)
-}
-
-func NewUAPIListener(name string) (net.Listener, error) {
-
- // check if path exist
-
- err := os.MkdirAll(socketDirectory, 077)
- if err != nil && !os.IsExist(err) {
- return nil, err
- }
-
- // open UNIX socket
-
- socketPath := path.Join(
- socketDirectory,
- fmt.Sprintf(socketName, name),
- )
-
- listener, err := connectUnixSocket(socketPath)
+ listener, err := net.FileListener(file)
if err != nil {
return nil, err
}
@@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) {
// watch for deletion of socket
+ socketPath := path.Join(
+ socketDirectory,
+ fmt.Sprintf(socketName, name),
+ )
+
uapi.inotifyFd, err = unix.InotifyInit()
if err != nil {
return nil, err
@@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) {
go func(l *UAPIListener) {
var buff [4096]byte
for {
- unix.Read(uapi.inotifyFd, buff[:])
+ // start with lstat to avoid race condition
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
l.connErr <- err
return
}
+ unix.Read(uapi.inotifyFd, buff[:])
}
}(uapi)
@@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) {
return uapi, nil
}
+
+func UAPIOpen(name string) (*os.File, error) {
+
+ // check if path exist
+
+ err := os.MkdirAll(socketDirectory, 0600)
+ if err != nil && !os.IsExist(err) {
+ return nil, err
+ }
+
+ // open UNIX socket
+
+ socketPath := path.Join(
+ socketDirectory,
+ fmt.Sprintf(socketName, name),
+ )
+
+ addr, err := net.ResolveUnixAddr("unix", socketPath)
+ if err != nil {
+ return nil, err
+ }
+
+ listener, err := func() (*net.UnixListener, error) {
+
+ // initial connection attempt
+
+ listener, err := net.ListenUnix("unix", addr)
+ if err == nil {
+ return listener, nil
+ }
+
+ // check if socket already active
+
+ _, err = net.Dial("unix", socketPath)
+ if err == nil {
+ return nil, errors.New("unix socket in use")
+ }
+
+ // cleanup & attempt again
+
+ err = os.Remove(socketPath)
+ if err != nil {
+ return nil, err
+ }
+ return net.ListenUnix("unix", addr)
+ }()
+
+ if err != nil {
+ return nil, err
+ }
+
+ return listener.File()
+}