aboutsummaryrefslogtreecommitdiff
path: root/conn/bind_linux.go
diff options
context:
space:
mode:
authorJordan Whited <jordan@tailscale.com>2023-03-02 15:08:28 -0800
committerJason A. Donenfeld <Jason@zx2c4.com>2023-03-10 14:52:17 +0100
commit9e2f3860220280a5630971478b53c8ad9a991ca8 (patch)
tree218f1bd9a8dd649a8fdb50571a921d1ccff4cae5 /conn/bind_linux.go
parent3bb8fec7e41fcc2138ddb4cba3f46100814fc523 (diff)
downloadwireguard-go-9e2f3860220280a5630971478b53c8ad9a991ca8.tar.gz
wireguard-go-9e2f3860220280a5630971478b53c8ad9a991ca8.zip
conn, device, tun: implement vectorized I/O on Linux
Implement TCP offloading via TSO and GRO for the Linux tun.Device, which is made possible by virtio extensions in the kernel's TUN driver. Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind. conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All platforms now fall under conn.StdNetBind, except for Windows, which remains in conn.WinRingBind, which still needs to be adjusted to handle multiple packets. Also refactor sticky sockets support to eventually be applicable on platforms other than just Linux. However Linux remains the sole platform that fully implements it for now. Co-authored-by: James Tucker <james@tailscale.com> Signed-off-by: James Tucker <james@tailscale.com> Signed-off-by: Jordan Whited <jordan@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'conn/bind_linux.go')
-rw-r--r--conn/bind_linux.go587
1 files changed, 0 insertions, 587 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
deleted file mode 100644
index b6bc0dc..0000000
--- a/conn/bind_linux.go
+++ /dev/null
@@ -1,587 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
- */
-
-package conn
-
-import (
- "errors"
- "net"
- "net/netip"
- "strconv"
- "sync"
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/unix"
-)
-
-type ipv4Source struct {
- Src [4]byte
- Ifindex int32
-}
-
-type ipv6Source struct {
- src [16]byte
- // ifindex belongs in dst.ZoneId
-}
-
-type LinuxSocketEndpoint struct {
- mu sync.Mutex
- dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
- src [unsafe.Sizeof(ipv6Source{})]byte
- isV6 bool
-}
-
-func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() }
-func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
-func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 }
-
-func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
- return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source {
- return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 {
- return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
- return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
-}
-
-// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
-type LinuxSocketBind struct {
- // mu guards sock4 and sock6 and the associated fds.
- // As long as someone holds mu (read or write), the associated fds are valid.
- mu sync.RWMutex
- sock4 int
- sock6 int
-}
-
-func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
-func NewDefaultBind() Bind { return NewLinuxSocketBind() }
-
-var (
- _ Endpoint = (*LinuxSocketEndpoint)(nil)
- _ Bind = (*LinuxSocketBind)(nil)
-)
-
-func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
- var end LinuxSocketEndpoint
- e, err := netip.ParseAddrPort(s)
- if err != nil {
- return nil, err
- }
-
- if e.Addr().Is4() {
- dst := end.dst4()
- end.isV6 = false
- dst.Port = int(e.Port())
- dst.Addr = e.Addr().As4()
- end.ClearSrc()
- return &end, nil
- }
-
- if e.Addr().Is6() {
- zone, err := zoneToUint32(e.Addr().Zone())
- if err != nil {
- return nil, err
- }
- dst := end.dst6()
- end.isV6 = true
- dst.Port = int(e.Port())
- dst.ZoneId = zone
- dst.Addr = e.Addr().As16()
- end.ClearSrc()
- return &end, nil
- }
-
- return nil, errors.New("invalid IP address")
-}
-
-func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
- bind.mu.Lock()
- defer bind.mu.Unlock()
-
- var err error
- var newPort uint16
- var tries int
-
- if bind.sock4 != -1 || bind.sock6 != -1 {
- return nil, 0, ErrBindAlreadyOpen
- }
-
- originalPort := port
-
-again:
- port = originalPort
- var sock4, sock6 int
- // Attempt ipv6 bind, update port if successful.
- sock6, newPort, err = create6(port)
- if err != nil {
- if !errors.Is(err, syscall.EAFNOSUPPORT) {
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- // Attempt ipv4 bind, update port if successful.
- sock4, newPort, err = create4(port)
- if err != nil {
- if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
- unix.Close(sock6)
- tries++
- goto again
- }
- if !errors.Is(err, syscall.EAFNOSUPPORT) {
- unix.Close(sock6)
- return nil, 0, err
- }
- } else {
- port = newPort
- }
-
- var fns []ReceiveFunc
- if sock4 != -1 {
- bind.sock4 = sock4
- fns = append(fns, bind.receiveIPv4)
- }
- if sock6 != -1 {
- bind.sock6 = sock6
- fns = append(fns, bind.receiveIPv6)
- }
- if len(fns) == 0 {
- return nil, 0, syscall.EAFNOSUPPORT
- }
- return fns, port, nil
-}
-
-func (bind *LinuxSocketBind) SetMark(value uint32) error {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
-
- if bind.sock6 != -1 {
- err := unix.SetsockoptInt(
- bind.sock6,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
- if err != nil {
- return err
- }
- }
-
- if bind.sock4 != -1 {
- err := unix.SetsockoptInt(
- bind.sock4,
- unix.SOL_SOCKET,
- unix.SO_MARK,
- int(value),
- )
- if err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (bind *LinuxSocketBind) BatchSize() int {
- return 1
-}
-
-func (bind *LinuxSocketBind) Close() error {
- // Take a readlock to shut down the sockets...
- bind.mu.RLock()
- if bind.sock6 != -1 {
- unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
- }
- if bind.sock4 != -1 {
- unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
- }
- bind.mu.RUnlock()
- // ...and a write lock to close the fd.
- // This ensures that no one else is using the fd.
- bind.mu.Lock()
- defer bind.mu.Unlock()
- var err1, err2 error
- if bind.sock6 != -1 {
- err1 = unix.Close(bind.sock6)
- bind.sock6 = -1
- }
- if bind.sock4 != -1 {
- err2 = unix.Close(bind.sock4)
- bind.sock4 = -1
- }
-
- if err1 != nil {
- return err1
- }
- return err2
-}
-
-func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if bind.sock4 == -1 {
- return 0, net.ErrClosed
- }
- var end LinuxSocketEndpoint
- n, err := receive4(bind.sock4, buffs[0], &end)
- if err != nil {
- return 0, err
- }
- eps[0] = &end
- sizes[0] = n
- return 1, nil
-}
-
-func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if bind.sock6 == -1 {
- return 0, net.ErrClosed
- }
- var end LinuxSocketEndpoint
- n, err := receive6(bind.sock6, buffs[0], &end)
- if err != nil {
- return 0, err
- }
- eps[0] = &end
- sizes[0] = n
- return 1, nil
-}
-
-func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error {
- nend, ok := end.(*LinuxSocketEndpoint)
- if !ok {
- return ErrWrongEndpointType
- }
- bind.mu.RLock()
- defer bind.mu.RUnlock()
- if !nend.isV6 {
- if bind.sock4 == -1 {
- return net.ErrClosed
- }
- for _, buff := range buffs {
- err := send4(bind.sock4, nend, buff)
- if err != nil {
- return err
- }
- }
- } else {
- if bind.sock6 == -1 {
- return net.ErrClosed
- }
- for _, buff := range buffs {
- err := send6(bind.sock6, nend, buff)
- if err != nil {
- return err
- }
- }
- }
- return nil
-}
-
-func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
- if !end.isV6 {
- return netip.AddrFrom4(end.src4().Src)
- } else {
- return netip.AddrFrom16(end.src6().src)
- }
-}
-
-func (end *LinuxSocketEndpoint) DstIP() netip.Addr {
- if !end.isV6 {
- return netip.AddrFrom4(end.dst4().Addr)
- } else {
- return netip.AddrFrom16(end.dst6().Addr)
- }
-}
-
-func (end *LinuxSocketEndpoint) DstToBytes() []byte {
- if !end.isV6 {
- return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
- } else {
- return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
- }
-}
-
-func (end *LinuxSocketEndpoint) SrcToString() string {
- return end.SrcIP().String()
-}
-
-func (end *LinuxSocketEndpoint) DstToString() string {
- var port int
- if !end.isV6 {
- port = end.dst4().Port
- } else {
- port = end.dst6().Port
- }
- return netip.AddrPortFrom(end.DstIP(), uint16(port)).String()
-}
-
-func (end *LinuxSocketEndpoint) ClearDst() {
- for i := range end.dst {
- end.dst[i] = 0
- }
-}
-
-func (end *LinuxSocketEndpoint) ClearSrc() {
- for i := range end.src {
- end.src[i] = 0
- }
-}
-
-func zoneToUint32(zone string) (uint32, error) {
- if zone == "" {
- return 0, nil
- }
- if intr, err := net.InterfaceByName(zone); err == nil {
- return uint32(intr.Index), nil
- }
- n, err := strconv.ParseUint(zone, 10, 32)
- return uint32(n), err
-}
-
-func create4(port uint16) (int, uint16, error) {
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return -1, 0, err
- }
-
- addr := unix.SockaddrInet4{
- Port: int(port),
- }
-
- // set sockopts and bind
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IP,
- unix.IP_PKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return -1, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet4).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func create6(port uint16) (int, uint16, error) {
- // create socket
-
- fd, err := unix.Socket(
- unix.AF_INET6,
- unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
- 0,
- )
- if err != nil {
- return -1, 0, err
- }
-
- // set sockopts and bind
-
- addr := unix.SockaddrInet6{
- Port: int(port),
- }
-
- if err := func() error {
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_RECVPKTINFO,
- 1,
- ); err != nil {
- return err
- }
-
- if err := unix.SetsockoptInt(
- fd,
- unix.IPPROTO_IPV6,
- unix.IPV6_V6ONLY,
- 1,
- ); err != nil {
- return err
- }
-
- return unix.Bind(fd, &addr)
- }(); err != nil {
- unix.Close(fd)
- return -1, 0, err
- }
-
- sa, err := unix.Getsockname(fd)
- if err == nil {
- addr.Port = sa.(*unix.SockaddrInet6).Port
- }
-
- return fd, uint16(addr.Port), err
-}
-
-func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IP,
- Type: unix.IP_PKTINFO,
- Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet4Pktinfo{
- Spec_dst: end.src4().Src,
- Ifindex: end.src4().Ifindex,
- },
- }
-
- end.mu.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet4Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
- }
-
- return err
-}
-
-func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error {
- // construct message header
-
- cmsg := struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }{
- unix.Cmsghdr{
- Level: unix.IPPROTO_IPV6,
- Type: unix.IPV6_PKTINFO,
- Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
- },
- unix.Inet6Pktinfo{
- Addr: end.src6().src,
- Ifindex: end.dst6().ZoneId,
- },
- }
-
- if cmsg.pktinfo.Addr == [16]byte{} {
- cmsg.pktinfo.Ifindex = 0
- }
-
- end.mu.Lock()
- _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.mu.Unlock()
-
- if err == nil {
- return nil
- }
-
- // clear src and retry
-
- if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet6Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
- end.mu.Unlock()
- }
-
- return err
-}
-
-func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
- if err != nil {
- return 0, err
- }
- end.isV6 = false
-
- if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
- *end.dst4() = *newDst4
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
- cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
- end.src4().Src = cmsg.pktinfo.Spec_dst
- end.src4().Ifindex = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}
-
-func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) {
- // construct message header
-
- var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
- }
-
- size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
- if err != nil {
- return 0, err
- }
- end.isV6 = true
-
- if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
- *end.dst6() = *newDst6
- }
-
- // update source cache
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
- cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
- end.src6().src = cmsg.pktinfo.Addr
- end.dst6().ZoneId = cmsg.pktinfo.Ifindex
- }
-
- return size, nil
-}