diff options
Diffstat (limited to '')
-rw-r--r-- | conn/bind_linux.go (renamed from conn/conn_linux.go) | 108 |
1 files changed, 56 insertions, 52 deletions
diff --git a/conn/conn_linux.go b/conn/bind_linux.go index 58b7de1..4199809 100644 --- a/conn/conn_linux.go +++ b/conn/bind_linux.go @@ -1,5 +1,3 @@ -// +build !android - /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. @@ -18,55 +16,59 @@ import ( "golang.org/x/sys/unix" ) -type IPv4Source struct { +type ipv4Source struct { Src [4]byte Ifindex int32 } -type IPv6Source struct { +type ipv6Source struct { src [16]byte - //ifindex belongs in dst.ZoneId + // ifindex belongs in dst.ZoneId } -type NativeEndpoint struct { +type LinuxSocketEndpoint struct { sync.Mutex dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte - src [unsafe.Sizeof(IPv6Source{})]byte + src [unsafe.Sizeof(ipv6Source{})]byte isV6 bool } -func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() } -func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } -func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 } +func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() } +func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } +func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 } -func (endpoint *NativeEndpoint) src4() *IPv4Source { - return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) +func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source { + return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0])) } -func (endpoint *NativeEndpoint) src6() *IPv6Source { - return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0])) +func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source { + return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0])) } -func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 { +func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 { return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) } -func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { +func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 { return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) } -type nativeBind struct { +// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux. +type LinuxSocketBind struct { sock4 int sock6 int lastMark uint32 closing sync.RWMutex } -var _ Endpoint = (*NativeEndpoint)(nil) -var _ Bind = (*nativeBind)(nil) +func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} } +func NewDefaultBind() Bind { return NewLinuxSocketBind() } + +var _ Endpoint = (*LinuxSocketEndpoint)(nil) +var _ Bind = (*LinuxSocketBind)(nil) -func CreateEndpoint(s string) (Endpoint, error) { - var end NativeEndpoint +func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { + var end LinuxSocketEndpoint addr, err := parseEndpoint(s) if err != nil { return nil, err @@ -97,14 +99,18 @@ func CreateEndpoint(s string) (Endpoint, error) { return &end, nil } - return nil, errors.New("Invalid IP address") + return nil, errors.New("invalid IP address") } -func createBind(port uint16) (Bind, uint16, error) { +func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) { var err error - var bind nativeBind var newPort uint16 var tries int + + if bind.sock4 != -1 || bind.sock6 != -1 { + return 0, ErrBindAlreadyOpen + } + originalPort := port again: @@ -113,7 +119,7 @@ again: bind.sock6, newPort, err = create6(port) if err != nil { if err != syscall.EAFNOSUPPORT { - return nil, 0, err + return 0, err } } else { port = newPort @@ -129,24 +135,19 @@ again: } if err != syscall.EAFNOSUPPORT { unix.Close(bind.sock6) - return nil, 0, err + return 0, err } } else { port = newPort } if bind.sock4 == -1 && bind.sock6 == -1 { - return nil, 0, errors.New("ipv4 and ipv6 not supported") + return 0, syscall.EAFNOSUPPORT } - - return &bind, port, nil -} - -func (bind *nativeBind) LastMark() uint32 { - return bind.lastMark + return port, nil } -func (bind *nativeBind) SetMark(value uint32) error { +func (bind *LinuxSocketBind) SetMark(value uint32) error { bind.closing.RLock() defer bind.closing.RUnlock() @@ -180,7 +181,7 @@ func (bind *nativeBind) SetMark(value uint32) error { return nil } -func (bind *nativeBind) Close() error { +func (bind *LinuxSocketBind) Close() error { var err1, err2 error bind.closing.RLock() if bind.sock6 != -1 { @@ -207,11 +208,11 @@ func (bind *nativeBind) Close() error { return err2 } -func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { +func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { bind.closing.RLock() defer bind.closing.RUnlock() - var end NativeEndpoint + var end LinuxSocketEndpoint if bind.sock6 == -1 { return 0, nil, net.ErrClosed } @@ -223,11 +224,11 @@ func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { return n, &end, err } -func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { +func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { bind.closing.RLock() defer bind.closing.RUnlock() - var end NativeEndpoint + var end LinuxSocketEndpoint if bind.sock4 == -1 { return 0, nil, net.ErrClosed } @@ -239,11 +240,14 @@ func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { return n, &end, err } -func (bind *nativeBind) Send(buff []byte, end Endpoint) error { +func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { bind.closing.RLock() defer bind.closing.RUnlock() - nend := end.(*NativeEndpoint) + nend, ok := end.(*LinuxSocketEndpoint) + if !ok { + return ErrWrongEndpointType + } if !nend.isV6 { if bind.sock4 == -1 { return net.ErrClosed @@ -257,7 +261,7 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error { } } -func (end *NativeEndpoint) SrcIP() net.IP { +func (end *LinuxSocketEndpoint) SrcIP() net.IP { if !end.isV6 { return net.IPv4( end.src4().Src[0], @@ -270,7 +274,7 @@ func (end *NativeEndpoint) SrcIP() net.IP { } } -func (end *NativeEndpoint) DstIP() net.IP { +func (end *LinuxSocketEndpoint) DstIP() net.IP { if !end.isV6 { return net.IPv4( end.dst4().Addr[0], @@ -283,7 +287,7 @@ func (end *NativeEndpoint) DstIP() net.IP { } } -func (end *NativeEndpoint) DstToBytes() []byte { +func (end *LinuxSocketEndpoint) DstToBytes() []byte { if !end.isV6 { return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] } else { @@ -291,11 +295,11 @@ func (end *NativeEndpoint) DstToBytes() []byte { } } -func (end *NativeEndpoint) SrcToString() string { +func (end *LinuxSocketEndpoint) SrcToString() string { return end.SrcIP().String() } -func (end *NativeEndpoint) DstToString() string { +func (end *LinuxSocketEndpoint) DstToString() string { var udpAddr net.UDPAddr udpAddr.IP = end.DstIP() if !end.isV6 { @@ -306,13 +310,13 @@ func (end *NativeEndpoint) DstToString() string { return udpAddr.String() } -func (end *NativeEndpoint) ClearDst() { +func (end *LinuxSocketEndpoint) ClearDst() { for i := range end.dst { end.dst[i] = 0 } } -func (end *NativeEndpoint) ClearSrc() { +func (end *LinuxSocketEndpoint) ClearSrc() { for i := range end.src { end.src[i] = 0 } @@ -427,7 +431,7 @@ func create6(port uint16) (int, uint16, error) { return fd, uint16(addr.Port), err } -func send4(sock int, end *NativeEndpoint, buff []byte) error { +func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error { // construct message header @@ -467,7 +471,7 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error { return err } -func send6(sock int, end *NativeEndpoint, buff []byte) error { +func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error { // construct message header @@ -511,7 +515,7 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error { return err } -func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { +func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) { // construct message header @@ -543,7 +547,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { return size, nil } -func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { +func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) { // construct message header |