/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn import ( "context" "errors" "net" "net/netip" "runtime" "strconv" "sync" "syscall" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) var ( _ Bind = (*StdNetBind)(nil) ) // StdNetBind implements Bind for all platforms. While Windows has its own Bind // (see bind_windows.go), it may fall back to StdNetBind. // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable // methods for sending and receiving multiple datagrams per-syscall. See the // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. type StdNetBind struct { mu sync.Mutex // protects all fields except as specified ipv4 *net.UDPConn ipv6 *net.UDPConn ipv4PC *ipv4.PacketConn // will be nil on non-Linux ipv6PC *ipv6.PacketConn // will be nil on non-Linux // these three fields are not guarded by mu udpAddrPool sync.Pool ipv4MsgsPool sync.Pool ipv6MsgsPool sync.Pool blackhole4 bool blackhole6 bool } func NewStdNetBind() Bind { return &StdNetBind{ udpAddrPool: sync.Pool{ New: func() any { return &net.UDPAddr{ IP: make([]byte, 16), } }, }, ipv4MsgsPool: sync.Pool{ New: func() any { msgs := make([]ipv4.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) msgs[i].OOB = make([]byte, srcControlSize) } return &msgs }, }, ipv6MsgsPool: sync.Pool{ New: func() any { msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) msgs[i].OOB = make([]byte, srcControlSize) } return &msgs }, }, } } type StdNetEndpoint struct { // AddrPort is the endpoint destination. netip.AddrPort // src is the current sticky source address and interface index, if supported. src struct { netip.Addr ifidx int32 } } var ( _ Bind = (*StdNetBind)(nil) _ Endpoint = &StdNetEndpoint{} ) func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { e, err := netip.ParseAddrPort(s) if err != nil { return nil, err } return &StdNetEndpoint{ AddrPort: e, }, nil } func (e *StdNetEndpoint) ClearSrc() { e.src.ifidx = 0 e.src.Addr = netip.Addr{} } func (e *StdNetEndpoint) DstIP() netip.Addr { return e.AddrPort.Addr() } func (e *StdNetEndpoint) SrcIP() netip.Addr { return e.src.Addr } func (e *StdNetEndpoint) SrcIfidx() int32 { return e.src.ifidx } func (e *StdNetEndpoint) DstToBytes() []byte { b, _ := e.AddrPort.MarshalBinary() return b } func (e *StdNetEndpoint) DstToString() string { return e.AddrPort.String() } func (e *StdNetEndpoint) SrcToString() string { return e.src.Addr.String() } func listenNet(network string, port int) (*net.UDPConn, int, error) { conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { return nil, 0, err } // Retrieve port. laddr := conn.LocalAddr() uaddr, err := net.ResolveUDPAddr( laddr.Network(), laddr.String(), ) if err != nil { return nil, 0, err } return conn.(*net.UDPConn), uaddr.Port, nil } func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { s.mu.Lock() defer s.mu.Unlock() var err error var tries int if s.ipv4 != nil || s.ipv6 != nil { return nil, 0, ErrBindAlreadyOpen } // Attempt to open ipv4 and ipv6 listeners on the same port. // If uport is 0, we can retry on failure. again: port := int(uport) var v4conn, v6conn *net.UDPConn var v4pc *ipv4.PacketConn var v6pc *ipv6.PacketConn v4conn, port, err = listenNet("udp4", port) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } // Listen on the same port as we're using for ipv4. v6conn, port, err = listenNet("udp6", port) if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { v4conn.Close() tries++ goto again } if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { v4conn.Close() return nil, 0, err } var fns []ReceiveFunc if v4conn != nil { if runtime.GOOS == "linux" { v4pc = ipv4.NewPacketConn(v4conn) s.ipv4PC = v4pc } fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn)) s.ipv4 = v4conn } if v6conn != nil { if runtime.GOOS == "linux" { v6pc = ipv6.NewPacketConn(v6conn) s.ipv6PC = v6pc } fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn)) s.ipv6 = v6conn } if len(fns) == 0 { return nil, 0, syscall.EAFNOSUPPORT } return fns, uint16(port), nil } func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) defer s.ipv4MsgsPool.Put(msgs) for i := range bufs { (*msgs)[i].Buffers[0] = bufs[i] } var numMsgs int if runtime.GOOS == "linux" { numMsgs, err = pc.ReadBatch(*msgs, 0) if err != nil { return 0, err } } else { msg := &(*msgs)[0] msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) if err != nil { return 0, err } numMsgs = 1 } for i := 0; i < numMsgs; i++ { msg := &(*msgs)[i] sizes[i] = msg.N addrPort := msg.Addr.(*net.UDPAddr).AddrPort() ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation getSrcFromControl(msg.OOB[:msg.NN], ep) eps[i] = ep } return numMsgs, nil } } func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) defer s.ipv6MsgsPool.Put(msgs) for i := range bufs { (*msgs)[i].Buffers[0] = bufs[i] } var numMsgs int if runtime.GOOS == "linux" { numMsgs, err = pc.ReadBatch(*msgs, 0) if err != nil { return 0, err } } else { msg := &(*msgs)[0] msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) if err != nil { return 0, err } numMsgs = 1 } for i := 0; i < numMsgs; i++ { msg := &(*msgs)[i] sizes[i] = msg.N addrPort := msg.Addr.(*net.UDPAddr).AddrPort() ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation getSrcFromControl(msg.OOB[:msg.NN], ep) eps[i] = ep } return numMsgs, nil } } // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and // rename the IdealBatchSize constant to BatchSize. func (s *StdNetBind) BatchSize() int { if runtime.GOOS == "linux" { return IdealBatchSize } return 1 } func (s *StdNetBind) Close() error { s.mu.Lock() defer s.mu.Unlock() var err1, err2 error if s.ipv4 != nil { err1 = s.ipv4.Close() s.ipv4 = nil s.ipv4PC = nil } if s.ipv6 != nil { err2 = s.ipv6.Close() s.ipv6 = nil s.ipv6PC = nil } s.blackhole4 = false s.blackhole6 = false if err1 != nil { return err1 } return err2 } func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 var ( pc4 *ipv4.PacketConn pc6 *ipv6.PacketConn ) is6 := false if endpoint.DstIP().Is6() { blackhole = s.blackhole6 conn = s.ipv6 pc6 = s.ipv6PC is6 = true } else { pc4 = s.ipv4PC } s.mu.Unlock() if blackhole { return nil } if conn == nil { return syscall.EAFNOSUPPORT } if is6 { return s.send6(conn, pc6, endpoint, bufs) } else { return s.send4(conn, pc4, endpoint, bufs) } } func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, bufs [][]byte) error { ua := s.udpAddrPool.Get().(*net.UDPAddr) as4 := ep.DstIP().As4() copy(ua.IP, as4[:]) ua.IP = ua.IP[:4] ua.Port = int(ep.(*StdNetEndpoint).Port()) msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) for i, buf := range bufs { (*msgs)[i].Buffers[0] = buf (*msgs)[i].Addr = ua setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) } var ( n int err error start int ) if runtime.GOOS == "linux" { for { n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) if err != nil || n == len((*msgs)[start:len(bufs)]) { break } start += n } } else { for i, buf := range bufs { _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) if err != nil { break } } } s.udpAddrPool.Put(ua) s.ipv4MsgsPool.Put(msgs) return err } func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, bufs [][]byte) error { ua := s.udpAddrPool.Get().(*net.UDPAddr) as16 := ep.DstIP().As16() copy(ua.IP, as16[:]) ua.IP = ua.IP[:16] ua.Port = int(ep.(*StdNetEndpoint).Port()) msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) for i, buf := range bufs { (*msgs)[i].Buffers[0] = buf (*msgs)[i].Addr = ua setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) } var ( n int err error start int ) if runtime.GOOS == "linux" { for { n, err = pc.WriteBatch((*msgs)[start:len(bufs)], 0) if err != nil || n == len((*msgs)[start:len(bufs)]) { break } start += n } } else { for i, buf := range bufs { _, _, err = conn.WriteMsgUDP(buf, (*msgs)[i].OOB, ua) if err != nil { break } } } s.udpAddrPool.Put(ua) s.ipv6MsgsPool.Put(msgs) return err }