aboutsummaryrefslogtreecommitdiff
path: root/conn/bind_linux.go
diff options
context:
space:
mode:
Diffstat (limited to 'conn/bind_linux.go')
-rw-r--r--conn/bind_linux.go109
1 files changed, 54 insertions, 55 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
index 70ea609..9eec384 100644
--- a/conn/bind_linux.go
+++ b/conn/bind_linux.go
@@ -55,10 +55,11 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
type LinuxSocketBind struct {
- sock4 int
- sock6 int
- lastMark uint32
- closing sync.RWMutex
+ // mu guards sock4 and sock6 and the associated fds.
+ // As long as someone holds mu (read or write), the associated fds are valid.
+ mu sync.RWMutex
+ sock4 int
+ sock6 int
}
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
@@ -102,54 +103,67 @@ func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
return nil, errors.New("invalid IP address")
}
-func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) {
+func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+
var err error
var newPort uint16
var tries int
if bind.sock4 != -1 || bind.sock6 != -1 {
- return 0, ErrBindAlreadyOpen
+ return nil, 0, ErrBindAlreadyOpen
}
originalPort := port
again:
port = originalPort
+ var sock4, sock6 int
// Attempt ipv6 bind, update port if successful.
- bind.sock6, newPort, err = create6(port)
+ sock6, newPort, err = create6(port)
if err != nil {
- if err != syscall.EAFNOSUPPORT {
- return 0, err
+ if !errors.Is(err, syscall.EAFNOSUPPORT) {
+ return nil, 0, err
}
} else {
port = newPort
}
// Attempt ipv4 bind, update port if successful.
- bind.sock4, newPort, err = create4(port)
+ sock4, newPort, err = create4(port)
if err != nil {
- if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 {
- unix.Close(bind.sock6)
+ if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
+ unix.Close(sock6)
tries++
goto again
}
- if err != syscall.EAFNOSUPPORT {
- unix.Close(bind.sock6)
- return 0, err
+ if !errors.Is(err, syscall.EAFNOSUPPORT) {
+ unix.Close(sock6)
+ return nil, 0, err
}
} else {
port = newPort
}
- if bind.sock4 == -1 && bind.sock6 == -1 {
- return 0, syscall.EAFNOSUPPORT
+ var fns []ReceiveFunc
+ if sock4 != -1 {
+ fns = append(fns, makeReceiveIPv4(sock4))
+ bind.sock4 = sock4
+ }
+ if sock6 != -1 {
+ fns = append(fns, makeReceiveIPv6(sock6))
+ bind.sock6 = sock6
+ }
+ if len(fns) == 0 {
+ return nil, 0, syscall.EAFNOSUPPORT
}
- return port, nil
+ return fns, port, nil
}
func (bind *LinuxSocketBind) SetMark(value uint32) error {
- bind.closing.RLock()
- defer bind.closing.RUnlock()
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
if bind.sock6 != -1 {
err := unix.SetsockoptInt(
@@ -177,21 +191,24 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
}
}
- bind.lastMark = value
return nil
}
func (bind *LinuxSocketBind) Close() error {
- var err1, err2 error
- bind.closing.RLock()
+ // Take a readlock to shut down the sockets...
+ bind.mu.RLock()
if bind.sock6 != -1 {
unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
}
if bind.sock4 != -1 {
unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
}
- bind.closing.RUnlock()
- bind.closing.Lock()
+ bind.mu.RUnlock()
+ // ...and a write lock to close the fd.
+ // This ensures that no one else is using the fd.
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ var err1, err2 error
if bind.sock6 != -1 {
err1 = unix.Close(bind.sock6)
bind.sock6 = -1
@@ -200,7 +217,6 @@ func (bind *LinuxSocketBind) Close() error {
err2 = unix.Close(bind.sock4)
bind.sock4 = -1
}
- bind.closing.Unlock()
if err1 != nil {
return err1
@@ -208,46 +224,29 @@ func (bind *LinuxSocketBind) Close() error {
return err2
}
-func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
- bind.closing.RLock()
- defer bind.closing.RUnlock()
-
- var end LinuxSocketEndpoint
- if bind.sock6 == -1 {
- return 0, nil, net.ErrClosed
+func makeReceiveIPv6(sock int) ReceiveFunc {
+ return func(buff []byte) (int, Endpoint, error) {
+ var end LinuxSocketEndpoint
+ n, err := receive6(sock, buff, &end)
+ return n, &end, err
}
- n, err := receive6(
- bind.sock6,
- buff,
- &end,
- )
- return n, &end, err
}
-func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
- bind.closing.RLock()
- defer bind.closing.RUnlock()
-
- var end LinuxSocketEndpoint
- if bind.sock4 == -1 {
- return 0, nil, net.ErrClosed
+func makeReceiveIPv4(sock int) ReceiveFunc {
+ return func(buff []byte) (int, Endpoint, error) {
+ var end LinuxSocketEndpoint
+ n, err := receive4(sock, buff, &end)
+ return n, &end, err
}
- n, err := receive4(
- bind.sock4,
- buff,
- &end,
- )
- return n, &end, err
}
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
- bind.closing.RLock()
- defer bind.closing.RUnlock()
-
nend, ok := end.(*LinuxSocketEndpoint)
if !ok {
return ErrWrongEndpointType
}
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
if !nend.isV6 {
if bind.sock4 == -1 {
return net.ErrClosed