From a72b0f7ae5dda27d839bb317b7c01d11b215e77a Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sun, 8 Oct 2017 22:03:32 +0200 Subject: Added new UDPBind interface --- src/peer.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'src/peer.go') diff --git a/src/peer.go b/src/peer.go index 6fea829..791c091 100644 --- a/src/peer.go +++ b/src/peer.go @@ -4,7 +4,6 @@ import ( "encoding/base64" "errors" "fmt" - "net" "sync" "time" ) @@ -15,8 +14,8 @@ type Peer struct { persistentKeepaliveInterval uint64 keyPairs KeyPairs handshake Handshake + endpoint Endpoint device *Device - endpoint *net.UDPAddr stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer @@ -134,7 +133,7 @@ func (peer *Peer) String() string { return fmt.Sprintf( "peer(%d %s %s)", peer.id, - peer.endpoint.String(), + peer.endpoint.DestinationToString(), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } -- cgit v1.2.3 From e86d03dca23e5adcbd1c7bd30157bc7d19a932d7 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Mon, 16 Oct 2017 21:33:47 +0200 Subject: Initial implementation of source caching Yet untested. --- src/conn.go | 21 ++++++++++++++----- src/conn_linux.go | 12 +++++++---- src/device.go | 2 +- src/main.go | 2 -- src/peer.go | 24 ++++++++++++++++++--- src/receive.go | 18 +++++++--------- src/send.go | 19 +++++------------ src/timers.go | 2 +- src/tun.go | 4 ++-- src/uapi.go | 63 +++++++++++++++++++------------------------------------ 10 files changed, 84 insertions(+), 83 deletions(-) (limited to 'src/peer.go') diff --git a/src/conn.go b/src/conn.go index db4020d..012e24e 100644 --- a/src/conn.go +++ b/src/conn.go @@ -34,15 +34,20 @@ func parseEndpoint(s string) (*net.UDPAddr, error) { return addr, err } -func ListeningUpdate(device *Device) error { +func UpdateUDPListener(device *Device) error { + device.mutex.Lock() + defer device.mutex.Unlock() + netc := &device.net netc.mutex.Lock() defer netc.mutex.Unlock() // close existing sockets - if err := device.net.bind.Close(); err != nil { - return err + if netc.bind != nil { + if err := netc.bind.Close(); err != nil { + return err + } } // open new sockets @@ -64,13 +69,19 @@ func ListeningUpdate(device *Device) error { return err } - // TODO: clear endpoint (src) caches + // clear cached source addresses + + for _, peer := range device.peers { + peer.mutex.Lock() + peer.endpoint.value.ClearSrc() + peer.mutex.Unlock() + } } return nil } -func ListeningClose(device *Device) error { +func CloseUDPListener(device *Device) error { netc := &device.net netc.mutex.Lock() defer netc.mutex.Unlock() diff --git a/src/conn_linux.go b/src/conn_linux.go index 8942b03..4a5a3f0 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -133,7 +133,7 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string { } } -func (end *Endpoint) DestinationIP() net.IP { +func (end *Endpoint) DstIP() net.IP { switch end.dst.Family { case unix.AF_INET6: return end.dst.Addr[:] @@ -150,20 +150,24 @@ func (end *Endpoint) DestinationIP() net.IP { } } -func (end *Endpoint) SourceToBytes() []byte { +func (end *Endpoint) SrcToBytes() []byte { ptr := unsafe.Pointer(&end.src) arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) return arr[:] } -func (end *Endpoint) SourceToString() string { +func (end *Endpoint) SrcToString() string { return sockaddrToString(end.src) } -func (end *Endpoint) DestinationToString() string { +func (end *Endpoint) DstToString() string { return sockaddrToString(end.dst) } +func (end *Endpoint) ClearDst() { + end.dst = unix.RawSockaddrInet6{} +} + func (end *Endpoint) ClearSrc() { end.src = unix.RawSockaddrInet6{} } diff --git a/src/device.go b/src/device.go index d1e0685..1aae448 100644 --- a/src/device.go +++ b/src/device.go @@ -205,7 +205,7 @@ func (device *Device) RemoveAllPeers() { func (device *Device) Close() { device.RemoveAllPeers() close(device.signal.stop) - ListeningClose(device) + CloseUDPListener(device) } func (device *Device) WaitChannel() chan struct{} { diff --git a/src/main.go b/src/main.go index a05dbba..5aaed9b 100644 --- a/src/main.go +++ b/src/main.go @@ -14,8 +14,6 @@ func printUsage() { } func main() { - test() - // parse arguments var foreground bool diff --git a/src/peer.go b/src/peer.go index 791c091..f24dcd8 100644 --- a/src/peer.go +++ b/src/peer.go @@ -14,9 +14,12 @@ type Peer struct { persistentKeepaliveInterval uint64 keyPairs KeyPairs handshake Handshake - endpoint Endpoint device *Device - stats struct { + endpoint struct { + set bool // has a known endpoint been discovered + value Endpoint // source / destination cache + } + stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer lastHandshakeNano int64 // nano seconds since epoch @@ -105,6 +108,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic) handshake.mutex.Unlock() + // reset endpoint + + peer.endpoint.set = false + peer.endpoint.value.ClearDst() + peer.endpoint.value.ClearSrc() + // prepare queuing peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) @@ -129,11 +138,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } +/* Returns a short string identification for logging + */ func (peer *Peer) String() string { + if !peer.endpoint.set { + return fmt.Sprintf( + "peer(%d unknown %s)", + peer.id, + base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), + ) + } return fmt.Sprintf( "peer(%d %s %s)", peer.id, - peer.endpoint.DestinationToString(), + peer.endpoint.value.DstToString(), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } diff --git a/src/receive.go b/src/receive.go index 664f1ba..1f05b2f 100644 --- a/src/receive.go +++ b/src/receive.go @@ -331,7 +331,7 @@ func (device *Device) RoutineHandshake() { return } - srcBytes := elem.endpoint.SourceToBytes() + srcBytes := elem.endpoint.SrcToBytes() if device.IsUnderLoad() { // verify MAC2 field @@ -340,8 +340,7 @@ func (device *Device) RoutineHandshake() { // construct cookie reply - logDebug.Println("Sending cookie reply to:", elem.endpoint.SourceToString()) - + logDebug.Println("Sending cookie reply to:", elem.endpoint.SrcToString()) sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) if err != nil { @@ -365,9 +364,7 @@ func (device *Device) RoutineHandshake() { // check ratelimiter - if !device.ratelimiter.Allow( - elem.endpoint.DestinationIP(), - ) { + if !device.ratelimiter.Allow(elem.endpoint.DstIP()) { continue } } @@ -398,7 +395,7 @@ func (device *Device) RoutineHandshake() { if peer == nil { logInfo.Println( "Recieved invalid initiation message from", - elem.endpoint.DestinationToString(), + elem.endpoint.DstToString(), ) continue } @@ -412,7 +409,8 @@ func (device *Device) RoutineHandshake() { // TODO: Discover destination address also, only update on change peer.mutex.Lock() - peer.endpoint = elem.endpoint + peer.endpoint.set = true + peer.endpoint.value = elem.endpoint peer.mutex.Unlock() // create response @@ -435,7 +433,7 @@ func (device *Device) RoutineHandshake() { // send response - _, err = peer.SendBuffer(packet) + err = peer.SendBuffer(packet) if err == nil { peer.TimerAnyAuthenticatedPacketTraversal() } @@ -458,7 +456,7 @@ func (device *Device) RoutineHandshake() { if peer == nil { logInfo.Println( "Recieved invalid response message from", - elem.endpoint.DestinationToString(), + elem.endpoint.DstToString(), ) continue } diff --git a/src/send.go b/src/send.go index 5c88ead..e37a736 100644 --- a/src/send.go +++ b/src/send.go @@ -105,24 +105,15 @@ func addToEncryptionQueue( } } -func (peer *Peer) SendBuffer(buffer []byte) (int, error) { +func (peer *Peer) SendBuffer(buffer []byte) error { peer.device.net.mutex.RLock() defer peer.device.net.mutex.RUnlock() - peer.mutex.RLock() defer peer.mutex.RUnlock() - - endpoint := peer.endpoint - if endpoint == nil { - return 0, errors.New("No known endpoint for peer") + if !peer.endpoint.set { + return errors.New("No known endpoint for peer") } - - conn := peer.device.net.conn - if conn == nil { - return 0, errors.New("No UDP socket for device") - } - - return conn.WriteToUDP(buffer, endpoint) + return peer.device.net.bind.Send(buffer, &peer.endpoint.value) } /* Reads packets from the TUN and inserts @@ -343,7 +334,7 @@ func (peer *Peer) RoutineSequentialSender() { // send message and return buffer to pool length := uint64(len(elem.packet)) - _, err := peer.SendBuffer(elem.packet) + err := peer.SendBuffer(elem.packet) device.PutMessageBuffer(elem.buffer) if err != nil { logDebug.Println("Failed to send authenticated packet to peer", peer.String()) diff --git a/src/timers.go b/src/timers.go index 99695ba..2a94005 100644 --- a/src/timers.go +++ b/src/timers.go @@ -288,7 +288,7 @@ func (peer *Peer) RoutineHandshakeInitiator() { packet := writer.Bytes() peer.mac.AddMacs(packet) - _, err = peer.SendBuffer(packet) + err = peer.SendBuffer(packet) if err != nil { logError.Println( "Failed to send handshake initiation message to", diff --git a/src/tun.go b/src/tun.go index 8e8c759..9eed987 100644 --- a/src/tun.go +++ b/src/tun.go @@ -47,7 +47,7 @@ func (device *Device) RoutineTUNEventReader() { if !device.tun.isUp.Get() { logInfo.Println("Interface set up") device.tun.isUp.Set(true) - updateUDPConn(device) + UpdateUDPListener(device) } } @@ -55,7 +55,7 @@ func (device *Device) RoutineTUNEventReader() { if device.tun.isUp.Get() { logInfo.Println("Interface set down") device.tun.isUp.Set(false) - closeUDPConn(device) + CloseUDPListener(device) } } } diff --git a/src/uapi.go b/src/uapi.go index 7d08e56..2de26ee 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -39,9 +39,10 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { send("private_key=" + device.privateKey.ToHex()) } - if device.net.addr != nil { - send(fmt.Sprintf("listen_port=%d", device.net.addr.Port)) + if device.net.port != 0 { + send(fmt.Sprintf("listen_port=%d", device.net.port)) } + if device.net.fwmark != 0 { send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) } @@ -52,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { defer peer.mutex.RUnlock() send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex()) - if peer.endpoint != nil { - send("endpoint=" + peer.endpoint.String()) + if peer.endpoint.set { + send("endpoint=" + peer.endpoint.value.DstToString()) } nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) @@ -137,53 +138,24 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorInvalid} } - - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) - if err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{Code: ipcErrorInvalid} - } - - device.net.mutex.Lock() - device.net.addr = addr - device.net.mutex.Unlock() - - err = updateUDPConn(device) - if err != nil { + device.net.port = uint16(port) + if err := UpdateUDPListener(device); err != nil { logError.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorPortInUse} } - // TODO: Clear source address of all peers - case "fwmark": fwmark, err := strconv.ParseUint(value, 10, 32) if err != nil { logError.Println("Invalid fwmark", err) return &IPCError{Code: ipcErrorInvalid} } - device.net.mutex.Lock() - if fwmark > 0 || device.net.fwmark > 0 { - device.net.fwmark = uint32(fwmark) - err := SetMark( - device.net.conn, - device.net.fwmark, - ) - if err != nil { - logError.Println("Failed to set fwmark:", err) - device.net.mutex.Unlock() - return &IPCError{Code: ipcErrorIO} - } - - // TODO: Clear source address of all peers - } + device.net.fwmark = uint32(fwmark) device.net.mutex.Unlock() case "public_key": - // switch to peer configuration - deviceConfig = false case "replace_peers": @@ -218,7 +190,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { device.mutex.RLock() if device.publicKey.Equals(pubKey) { - // create dummy instance + // create dummy instance (not added to device) peer = &Peer{} dummy = true @@ -244,6 +216,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } case "remove": + + // remove currently selected peer from device + if value != "true" { logError.Println("Failed to set remove, invalid value:", value) return &IPCError{Code: ipcErrorInvalid} @@ -256,6 +231,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { dummy = true case "preshared_key": + + // update PSK + peer.mutex.Lock() err := peer.handshake.presharedKey.FromHex(value) peer.mutex.Unlock() @@ -265,14 +243,17 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } case "endpoint": - addr, err := parseEndpoint(value) + + // set endpoint destination and reset handshake timer + + peer.mutex.Lock() + err := peer.endpoint.value.Set(value) + peer.endpoint.set = (err == nil) + peer.mutex.Unlock() if err != nil { logError.Println("Failed to set endpoint:", value) return &IPCError{Code: ipcErrorInvalid} } - peer.mutex.Lock() - peer.endpoint = addr - peer.mutex.Unlock() signalSend(peer.signal.handshakeReset) case "persistent_keepalive_interval": -- cgit v1.2.3 From 0485c34c8e20e4f7ea19bd3c3f52d2f4717caead Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 27 Oct 2017 10:43:37 +0200 Subject: Fixed message header length in conn_linux --- src/conn.go | 7 +++++++ src/conn_linux.go | 43 +++++++++++++++++++++++++++++++------------ src/main.go | 5 ++++- src/peer.go | 11 +++++++++++ src/send.go | 12 ------------ src/uapi.go | 2 +- 6 files changed, 54 insertions(+), 26 deletions(-) (limited to 'src/peer.go') diff --git a/src/conn.go b/src/conn.go index 012e24e..b2caffb 100644 --- a/src/conn.go +++ b/src/conn.go @@ -45,15 +45,20 @@ func UpdateUDPListener(device *Device) error { // close existing sockets if netc.bind != nil { + println("close bind") if err := netc.bind.Close(); err != nil { return err } + netc.bind = nil + println("closed") } // open new sockets if device.tun.isUp.Get() { + println("creat") + // bind to new port var err error @@ -69,6 +74,8 @@ func UpdateUDPListener(device *Device) error { return err } + println("okay") + // clear cached source addresses for _, peer := range device.peers { diff --git a/src/conn_linux.go b/src/conn_linux.go index 51ca4f3..8cda460 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -50,10 +50,12 @@ func CreateUDPBind(port uint16) (UDPBind, uint16, error) { if err != nil { unix.Close(bind.sock6) } - return &bind, port, err + println(bind.sock6) + println(bind.sock4) + return bind, port, err } -func (bind *NativeBind) SetMark(value uint32) error { +func (bind NativeBind) SetMark(value uint32) error { err := unix.SetsockoptInt( bind.sock6, unix.SOL_SOCKET, @@ -73,7 +75,7 @@ func (bind *NativeBind) SetMark(value uint32) error { ) } -func (bind *NativeBind) Close() error { +func (bind NativeBind) Close() error { err1 := unix.Close(bind.sock6) err2 := unix.Close(bind.sock4) if err1 != nil { @@ -82,7 +84,7 @@ func (bind *NativeBind) Close() error { return err2 } -func (bind *NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { +func (bind NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { return receive6( bind.sock6, buff, @@ -90,7 +92,7 @@ func (bind *NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { ) } -func (bind *NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { +func (bind NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { return receive4( bind.sock4, buff, @@ -98,7 +100,7 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { ) } -func (bind *NativeBind) Send(buff []byte, end *Endpoint) error { +func (bind NativeBind) Send(buff []byte, end *Endpoint) error { switch end.dst.Family { case unix.AF_INET6: return send6(bind.sock6, end, buff) @@ -236,7 +238,7 @@ func create6(port uint16) (int, uint16, error) { // create socket fd, err := unix.Socket( - unix.AF_INET, + unix.AF_INET6, unix.SOCK_DGRAM, 0, ) @@ -342,7 +344,7 @@ func send6(sock int, end *Endpoint, buff []byte) error { unix.Cmsghdr{ Level: unix.IPPROTO_IPV6, Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo, + Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, }, unix.Inet6Pktinfo{ Addr: end.src.Addr, @@ -368,15 +370,31 @@ func send6(sock int, end *Endpoint, buff []byte) error { uintptr(unsafe.Pointer(&msghdr)), 0, ) + + if errno == 0 { + return nil + } + + // clear src and retry + if errno == unix.EINVAL { end.ClearSrc() + cmsg.pktinfo = unix.Inet6Pktinfo{} + _, _, errno = unix.Syscall( + unix.SYS_SENDMSG, + uintptr(sock), + uintptr(unsafe.Pointer(&msghdr)), + 0, + ) } + return errno } func send4(sock int, end *Endpoint, buff []byte) error { println("send 4") println(end.DstToString()) + println(sock) // construct message header @@ -393,7 +411,7 @@ func send4(sock int, end *Endpoint, buff []byte) error { unix.Cmsghdr{ Level: unix.IPPROTO_IP, Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo, + Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, }, unix.Inet4Pktinfo{ Spec_dst: src4.src.Addr, @@ -419,10 +437,11 @@ func send4(sock int, end *Endpoint, buff []byte) error { 0, ) - println(sock) - fmt.Println(errno) + if errno == 0 { + return nil + } - // clear source cache and try again + // clear source and try again if errno == unix.EINVAL { end.ClearSrc() diff --git a/src/main.go b/src/main.go index 5aaed9b..05d56eb 100644 --- a/src/main.go +++ b/src/main.go @@ -84,7 +84,10 @@ func main() { logInfo := device.log.Info logError := device.log.Error - logInfo.Println("Starting device") + logDebug := device.log.Debug + + logInfo.Println("Device started") + logDebug.Println("Debug log enabled") // start configuration lister diff --git a/src/peer.go b/src/peer.go index f24dcd8..a98fc97 100644 --- a/src/peer.go +++ b/src/peer.go @@ -138,6 +138,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } +func (peer *Peer) SendBuffer(buffer []byte) error { + peer.device.net.mutex.RLock() + defer peer.device.net.mutex.RUnlock() + peer.mutex.RLock() + defer peer.mutex.RUnlock() + if !peer.endpoint.set { + return errors.New("No known endpoint for peer") + } + return peer.device.net.bind.Send(buffer, &peer.endpoint.value) +} + /* Returns a short string identification for logging */ func (peer *Peer) String() string { diff --git a/src/send.go b/src/send.go index e37a736..52872f6 100644 --- a/src/send.go +++ b/src/send.go @@ -2,7 +2,6 @@ package main import ( "encoding/binary" - "errors" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -105,17 +104,6 @@ func addToEncryptionQueue( } } -func (peer *Peer) SendBuffer(buffer []byte) error { - peer.device.net.mutex.RLock() - defer peer.device.net.mutex.RUnlock() - peer.mutex.RLock() - defer peer.mutex.RUnlock() - if !peer.endpoint.set { - return errors.New("No known endpoint for peer") - } - return peer.device.net.bind.Send(buffer, &peer.endpoint.value) -} - /* Reads packets from the TUN and inserts * into nonce queue for peer * diff --git a/src/uapi.go b/src/uapi.go index accffd1..5098e3d 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -135,7 +135,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "listen_port": port, err := strconv.ParseUint(value, 10, 16) if err != nil { - logError.Println("Failed to set listen_port:", err) + logError.Println("Failed to parse listen_port:", err) return &IPCError{Code: ipcErrorInvalid} } device.net.port = uint16(port) -- cgit v1.2.3 From d10126f883ad39567248540347b5469956ab8b2e Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 18 Nov 2017 23:34:02 +0100 Subject: Moved endpoint into interface and simplified peer --- src/conn.go | 20 ++++++++------ src/conn_linux.go | 83 ++++++++++++++++++++++++++++++++++--------------------- src/device.go | 6 ++-- src/peer.go | 19 +++++-------- src/receive.go | 29 ++++++++----------- src/uapi.go | 24 ++++++++++------ 6 files changed, 101 insertions(+), 80 deletions(-) (limited to 'src/peer.go') diff --git a/src/conn.go b/src/conn.go index 3cf00ab..74bb075 100644 --- a/src/conn.go +++ b/src/conn.go @@ -7,26 +7,28 @@ import ( "net" ) -type UDPBind interface { +/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic + */ +type Bind interface { SetMark(value uint32) error - ReceiveIPv6(buff []byte, end *Endpoint) (int, error) - ReceiveIPv4(buff []byte, end *Endpoint) (int, error) - Send(buff []byte, end *Endpoint) error + ReceiveIPv6(buff []byte) (int, Endpoint, error) + ReceiveIPv4(buff []byte) (int, Endpoint, error) + Send(buff []byte, end Endpoint) error Close() error } /* An Endpoint maintains the source/destination caching for a peer * - * dst : the remote address of a peer + * dst : the remote address of a peer ("endpoint" in uapi terminology) * src : the local address from which datagrams originate going to the peer - * */ -type UDPEndpoint interface { +type Endpoint interface { ClearSrc() // clears the source address ClearDst() // clears the destination address SrcToString() string // returns the local source address (ip:port) DstToString() string // returns the destination address (ip:port) DstToBytes() []byte // used for mac2 cookie calculations + SetDst(string) error // used for manually setting the endpoint (uapi) DstIP() net.IP SrcIP() net.IP } @@ -107,7 +109,9 @@ func UpdateUDPListener(device *Device) error { for _, peer := range device.peers { peer.mutex.Lock() - peer.endpoint.value.ClearSrc() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } peer.mutex.Unlock() } diff --git a/src/conn_linux.go b/src/conn_linux.go index fb576b1..46f873f 100644 --- a/src/conn_linux.go +++ b/src/conn_linux.go @@ -21,22 +21,24 @@ import ( * See e.g. https://github.com/golang/go/issues/17930 * So this code is remains platform dependent. */ - -type Endpoint struct { +type NativeEndpoint struct { src unix.RawSockaddrInet6 dst unix.RawSockaddrInet6 } -type IPv4Source struct { - src unix.RawSockaddrInet4 - Ifindex int32 -} - type NativeBind struct { sock4 int sock6 int } +var _ Endpoint = (*NativeEndpoint)(nil) +var _ Bind = NativeBind{} + +type IPv4Source struct { + src unix.RawSockaddrInet4 + Ifindex int32 +} + func htons(val uint16) uint16 { var out [unsafe.Sizeof(val)]byte binary.BigEndian.PutUint16(out[:], val) @@ -48,7 +50,11 @@ func ntohs(val uint16) uint16 { return binary.BigEndian.Uint16((*tmp)[:]) } -func CreateUDPBind(port uint16) (UDPBind, uint16, error) { +func NewEndpoint() Endpoint { + return &NativeEndpoint{} +} + +func CreateUDPBind(port uint16) (Bind, uint16, error) { var err error var bind NativeBind @@ -99,28 +105,33 @@ func (bind NativeBind) Close() error { return err2 } -func (bind NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) { - return receive6( +func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + var end NativeEndpoint + n, err := receive6( bind.sock6, buff, - end, + &end, ) + return n, &end, err } -func (bind NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) { - return receive4( +func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + var end NativeEndpoint + n, err := receive4( bind.sock4, buff, - end, + &end, ) + return n, &end, err } -func (bind NativeBind) Send(buff []byte, end *Endpoint) error { - switch end.dst.Family { +func (bind NativeBind) Send(buff []byte, end Endpoint) error { + nend := end.(*NativeEndpoint) + switch nend.dst.Family { case unix.AF_INET6: - return send6(bind.sock6, end, buff) + return send6(bind.sock6, nend, buff) case unix.AF_INET: - return send4(bind.sock4, end, buff) + return send4(bind.sock4, nend, buff) default: return errors.New("Unknown address family of destination") } @@ -151,12 +162,12 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string { } } -func (end *Endpoint) DstIP() net.IP { - switch end.dst.Family { +func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP { + switch addr.Family { case unix.AF_INET6: - return end.dst.Addr[:] + return addr.Addr[:] case unix.AF_INET: - ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) + ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr)) return net.IPv4( ptr.Addr[0], ptr.Addr[1], @@ -168,25 +179,33 @@ func (end *Endpoint) DstIP() net.IP { } } -func (end *Endpoint) DstToBytes() []byte { +func (end *NativeEndpoint) SrcIP() net.IP { + return rawAddrToIP(end.src) +} + +func (end *NativeEndpoint) DstIP() net.IP { + return rawAddrToIP(end.dst) +} + +func (end *NativeEndpoint) DstToBytes() []byte { ptr := unsafe.Pointer(&end.src) arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) return arr[:] } -func (end *Endpoint) SrcToString() string { +func (end *NativeEndpoint) SrcToString() string { return sockaddrToString(end.src) } -func (end *Endpoint) DstToString() string { +func (end *NativeEndpoint) DstToString() string { return sockaddrToString(end.dst) } -func (end *Endpoint) ClearDst() { +func (end *NativeEndpoint) ClearDst() { end.dst = unix.RawSockaddrInet6{} } -func (end *Endpoint) ClearSrc() { +func (end *NativeEndpoint) ClearSrc() { end.src = unix.RawSockaddrInet6{} } @@ -306,7 +325,7 @@ func create6(port uint16) (int, uint16, error) { return fd, uint16(addr.Port), err } -func (end *Endpoint) SetDst(s string) error { +func (end *NativeEndpoint) SetDst(s string) error { addr, err := parseEndpoint(s) if err != nil { return err @@ -342,7 +361,7 @@ func (end *Endpoint) SetDst(s string) error { return errors.New("Failed to recognize IP address format") } -func send6(sock int, end *Endpoint, buff []byte) error { +func send6(sock int, end *NativeEndpoint, buff []byte) error { // construct message header @@ -404,7 +423,7 @@ func send6(sock int, end *Endpoint, buff []byte) error { return errno } -func send4(sock int, end *Endpoint, buff []byte) error { +func send4(sock int, end *NativeEndpoint, buff []byte) error { // construct message header @@ -470,7 +489,7 @@ func send4(sock int, end *Endpoint, buff []byte) error { return errno } -func receive4(sock int, buff []byte, end *Endpoint) (int, error) { +func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { // contruct message header @@ -518,7 +537,7 @@ func receive4(sock int, buff []byte, end *Endpoint) (int, error) { return int(size), nil } -func receive6(sock int, buff []byte, end *Endpoint) (int, error) { +func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { // contruct message header diff --git a/src/device.go b/src/device.go index 0085cee..76235bd 100644 --- a/src/device.go +++ b/src/device.go @@ -22,9 +22,9 @@ type Device struct { } net struct { mutex sync.RWMutex - bind UDPBind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) + bind Bind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) } mutex sync.RWMutex privateKey NoisePrivateKey diff --git a/src/peer.go b/src/peer.go index a98fc97..f3eb6c2 100644 --- a/src/peer.go +++ b/src/peer.go @@ -15,11 +15,8 @@ type Peer struct { keyPairs KeyPairs handshake Handshake device *Device - endpoint struct { - set bool // has a known endpoint been discovered - value Endpoint // source / destination cache - } - stats struct { + endpoint Endpoint + stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer lastHandshakeNano int64 // nano seconds since epoch @@ -110,9 +107,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // reset endpoint - peer.endpoint.set = false - peer.endpoint.value.ClearDst() - peer.endpoint.value.ClearSrc() + peer.endpoint = nil // prepare queuing @@ -143,16 +138,16 @@ func (peer *Peer) SendBuffer(buffer []byte) error { defer peer.device.net.mutex.RUnlock() peer.mutex.RLock() defer peer.mutex.RUnlock() - if !peer.endpoint.set { + if peer.endpoint == nil { return errors.New("No known endpoint for peer") } - return peer.device.net.bind.Send(buffer, &peer.endpoint.value) + return peer.device.net.bind.Send(buffer, peer.endpoint) } /* Returns a short string identification for logging */ func (peer *Peer) String() string { - if !peer.endpoint.set { + if peer.endpoint == nil { return fmt.Sprintf( "peer(%d unknown %s)", peer.id, @@ -162,7 +157,7 @@ func (peer *Peer) String() string { return fmt.Sprintf( "peer(%d %s %s)", peer.id, - peer.endpoint.value.DstToString(), + peer.endpoint.DstToString(), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } diff --git a/src/receive.go b/src/receive.go index b8b06f7..27fdb8a 100644 --- a/src/receive.go +++ b/src/receive.go @@ -93,7 +93,7 @@ func (device *Device) addToHandshakeQueue( } } -func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) { +func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) { logDebug := device.log.Debug logDebug.Println("Routine, receive incomming, IP version:", IP) @@ -104,20 +104,21 @@ func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) { buffer := device.GetMessageBuffer() - var size int - var err error + var ( + err error + size int + endpoint Endpoint + ) for { // read next datagram - var endpoint Endpoint - switch IP { case ipv4.Version: - size, err = bind.ReceiveIPv4(buffer[:], &endpoint) + size, endpoint, err = bind.ReceiveIPv4(buffer[:]) case ipv6.Version: - size, err = bind.ReceiveIPv6(buffer[:], &endpoint) + size, endpoint, err = bind.ReceiveIPv6(buffer[:]) default: return } @@ -339,10 +340,7 @@ func (device *Device) RoutineHandshake() { writer := bytes.NewBuffer(temp[:0]) binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send( - writer.Bytes(), - &elem.endpoint, - ) + device.net.bind.Send(writer.Bytes(), elem.endpoint) if err != nil { logDebug.Println("Failed to send cookie reply:", err) } @@ -395,8 +393,7 @@ func (device *Device) RoutineHandshake() { // update endpoint peer.mutex.Lock() - peer.endpoint.set = true - peer.endpoint.value = elem.endpoint + peer.endpoint = elem.endpoint peer.mutex.Unlock() // create response @@ -452,8 +449,7 @@ func (device *Device) RoutineHandshake() { // update endpoint peer.mutex.Lock() - peer.endpoint.set = true - peer.endpoint.value = elem.endpoint + peer.endpoint = elem.endpoint peer.mutex.Unlock() logDebug.Println("Received handshake initation from", peer) @@ -527,8 +523,7 @@ func (peer *Peer) RoutineSequentialReceiver() { // update endpoint peer.mutex.Lock() - peer.endpoint.set = true - peer.endpoint.value = elem.endpoint + peer.endpoint = elem.endpoint peer.mutex.Unlock() // check for keep-alive diff --git a/src/uapi.go b/src/uapi.go index e1d0929..670ecc4 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -53,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { defer peer.mutex.RUnlock() send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex()) - if peer.endpoint.set { - send("endpoint=" + peer.endpoint.value.DstToString()) + if peer.endpoint != nil { + send("endpoint=" + peer.endpoint.DstToString()) } nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) @@ -255,17 +255,25 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "endpoint": - // set endpoint destination and reset handshake timer + // set endpoint destination + + err := func() error { + peer.mutex.Lock() + defer peer.mutex.Unlock() + + endpoint := NewEndpoint() + if err := endpoint.SetDst(value); err != nil { + return err + } + peer.endpoint = endpoint + signalSend(peer.signal.handshakeReset) + return nil + }() - peer.mutex.Lock() - err := peer.endpoint.value.SetDst(value) - peer.endpoint.set = (err == nil) - peer.mutex.Unlock() if err != nil { logError.Println("Failed to set endpoint:", value) return &IPCError{Code: ipcErrorInvalid} } - signalSend(peer.signal.handshakeReset) case "persistent_keepalive_interval": -- cgit v1.2.3