diff options
Diffstat (limited to '')
-rw-r--r-- | conn/bind_std.go (renamed from conn/conn_default.go) | 81 |
1 files changed, 45 insertions, 36 deletions
diff --git a/conn/conn_default.go b/conn/bind_std.go index 82a1e42..193c4fe 100644 --- a/conn/conn_default.go +++ b/conn/bind_std.go @@ -1,5 +1,3 @@ -// +build !linux android - /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. @@ -13,41 +11,40 @@ import ( "syscall" ) -/* This code is meant to be a temporary solution - * on platforms for which the sticky socket / source caching behavior - * has not yet been implemented. - * - * See conn_linux.go for an implementation on the linux platform. - */ - -type nativeBind struct { +// StdNetBind is meant to be a temporary solution on platforms for which +// the sticky socket / source caching behavior has not yet been implemented. +// It uses the Go's net package to implement networking. +// See LinuxSocketBind for a proper implementation on the Linux platform. +type StdNetBind struct { ipv4 *net.UDPConn ipv6 *net.UDPConn blackhole4 bool blackhole6 bool } -type NativeEndpoint net.UDPAddr +func NewStdNetBind() Bind { return &StdNetBind{} } -var _ Bind = (*nativeBind)(nil) -var _ Endpoint = (*NativeEndpoint)(nil) +type StdNetEndpoint net.UDPAddr -func CreateEndpoint(s string) (Endpoint, error) { +var _ Bind = (*StdNetBind)(nil) +var _ Endpoint = (*StdNetEndpoint)(nil) + +func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { addr, err := parseEndpoint(s) - return (*NativeEndpoint)(addr), err + return (*StdNetEndpoint)(addr), err } -func (*NativeEndpoint) ClearSrc() {} +func (*StdNetEndpoint) ClearSrc() {} -func (e *NativeEndpoint) DstIP() net.IP { +func (e *StdNetEndpoint) DstIP() net.IP { return (*net.UDPAddr)(e).IP } -func (e *NativeEndpoint) SrcIP() net.IP { +func (e *StdNetEndpoint) SrcIP() net.IP { return nil // not supported } -func (e *NativeEndpoint) DstToBytes() []byte { +func (e *StdNetEndpoint) DstToBytes() []byte { addr := (*net.UDPAddr)(e) out := addr.IP.To4() if out == nil { @@ -58,11 +55,11 @@ func (e *NativeEndpoint) DstToBytes() []byte { return out } -func (e *NativeEndpoint) DstToString() string { +func (e *StdNetEndpoint) DstToString() string { return (*net.UDPAddr)(e).String() } -func (e *NativeEndpoint) SrcToString() string { +func (e *StdNetEndpoint) SrcToString() string { return "" } @@ -84,41 +81,52 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { return conn, uaddr.Port, nil } -func createBind(uport uint16) (Bind, uint16, error) { +func (bind *StdNetBind) Open(uport uint16) (uint16, error) { var err error - var bind nativeBind var tries int + if bind.ipv4 != nil || bind.ipv6 != nil { + return 0, ErrBindAlreadyOpen + } + again: port := int(uport) bind.ipv4, port, err = listenNet("udp4", port) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - return nil, 0, err + bind.ipv4 = nil + return 0, err } bind.ipv6, port, err = listenNet("udp6", port) if uport == 0 && err != nil && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { bind.ipv4.Close() + bind.ipv4 = nil + bind.ipv6 = nil tries++ goto again } if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { bind.ipv4.Close() bind.ipv4 = nil - return nil, 0, err + bind.ipv6 = nil + return 0, err } - - return &bind, uint16(port), nil + if bind.ipv4 == nil && bind.ipv6 == nil { + return 0, syscall.EAFNOSUPPORT + } + return uint16(port), nil } -func (bind *nativeBind) Close() error { +func (bind *StdNetBind) Close() error { var err1, err2 error if bind.ipv4 != nil { err1 = bind.ipv4.Close() + bind.ipv4 = nil } if bind.ipv6 != nil { err2 = bind.ipv6.Close() + bind.ipv6 = nil } if err1 != nil { return err1 @@ -126,9 +134,7 @@ func (bind *nativeBind) Close() error { return err2 } -func (bind *nativeBind) LastMark() uint32 { return 0 } - -func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { +func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { if bind.ipv4 == nil { return 0, nil, syscall.EAFNOSUPPORT } @@ -136,20 +142,23 @@ func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { if endpoint != nil { endpoint.IP = endpoint.IP.To4() } - return n, (*NativeEndpoint)(endpoint), err + return n, (*StdNetEndpoint)(endpoint), err } -func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { +func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { if bind.ipv6 == nil { return 0, nil, syscall.EAFNOSUPPORT } n, endpoint, err := bind.ipv6.ReadFromUDP(buff) - return n, (*NativeEndpoint)(endpoint), err + return n, (*StdNetEndpoint)(endpoint), err } -func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error { +func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { var err error - nend := endpoint.(*NativeEndpoint) + nend, ok := endpoint.(*StdNetEndpoint) + if !ok { + return ErrWrongEndpointType + } if nend.IP.To4() != nil { if bind.ipv4 == nil { return syscall.EAFNOSUPPORT |