aboutsummaryrefslogtreecommitdiff
path: root/conn
diff options
context:
space:
mode:
Diffstat (limited to 'conn')
-rw-r--r--conn/bind_linux.go47
-rw-r--r--conn/bind_std.go37
-rw-r--r--conn/bind_windows.go67
-rw-r--r--conn/bindtest/bindtest.go39
-rw-r--r--conn/conn.go24
-rw-r--r--conn/conn_test.go24
6 files changed, 171 insertions, 67 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
index bd710ae..b6bc0dc 100644
--- a/conn/bind_linux.go
+++ b/conn/bind_linux.go
@@ -193,6 +193,10 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
return nil
}
+func (bind *LinuxSocketBind) BatchSize() int {
+ return 1
+}
+
func (bind *LinuxSocketBind) Close() error {
// Take a readlock to shut down the sockets...
bind.mu.RLock()
@@ -223,29 +227,39 @@ func (bind *LinuxSocketBind) Close() error {
return err2
}
-func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
+func (bind *LinuxSocketBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.sock4 == -1 {
- return 0, nil, net.ErrClosed
+ return 0, net.ErrClosed
}
var end LinuxSocketEndpoint
- n, err := receive4(bind.sock4, buf, &end)
- return n, &end, err
+ n, err := receive4(bind.sock4, buffs[0], &end)
+ if err != nil {
+ return 0, err
+ }
+ eps[0] = &end
+ sizes[0] = n
+ return 1, nil
}
-func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
+func (bind *LinuxSocketBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.sock6 == -1 {
- return 0, nil, net.ErrClosed
+ return 0, net.ErrClosed
}
var end LinuxSocketEndpoint
- n, err := receive6(bind.sock6, buf, &end)
- return n, &end, err
+ n, err := receive6(bind.sock6, buffs[0], &end)
+ if err != nil {
+ return 0, err
+ }
+ eps[0] = &end
+ sizes[0] = n
+ return 1, nil
}
-func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
+func (bind *LinuxSocketBind) Send(buffs [][]byte, end Endpoint) error {
nend, ok := end.(*LinuxSocketEndpoint)
if !ok {
return ErrWrongEndpointType
@@ -256,13 +270,24 @@ func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
if bind.sock4 == -1 {
return net.ErrClosed
}
- return send4(bind.sock4, nend, buff)
+ for _, buff := range buffs {
+ err := send4(bind.sock4, nend, buff)
+ if err != nil {
+ return err
+ }
+ }
} else {
if bind.sock6 == -1 {
return net.ErrClosed
}
- return send6(bind.sock6, nend, buff)
+ for _, buff := range buffs {
+ err := send6(bind.sock6, nend, buff)
+ if err != nil {
+ return err
+ }
+ }
}
+ return nil
}
func (end *LinuxSocketEndpoint) SrcIP() netip.Addr {
diff --git a/conn/bind_std.go b/conn/bind_std.go
index ae07aac..98fe23c 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -128,6 +128,10 @@ again:
return fns, uint16(port), nil
}
+func (bind *StdNetBind) BatchSize() int {
+ return 1
+}
+
func (bind *StdNetBind) Close() error {
bind.mu.Lock()
defer bind.mu.Unlock()
@@ -150,20 +154,30 @@ func (bind *StdNetBind) Close() error {
}
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
- return func(buff []byte) (int, Endpoint, error) {
- n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
- return n, asEndpoint(endpoint), err
+ return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
+ if err == nil {
+ sizes[0] = size
+ eps[0] = asEndpoint(endpoint)
+ return 1, nil
+ }
+ return 0, err
}
}
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
- return func(buff []byte) (int, Endpoint, error) {
- n, endpoint, err := conn.ReadFromUDPAddrPort(buff)
- return n, asEndpoint(endpoint), err
+ return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
+ size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
+ if err == nil {
+ sizes[0] = size
+ eps[0] = asEndpoint(endpoint)
+ return 1, nil
+ }
+ return 0, err
}
}
-func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
+func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
var err error
nend, ok := endpoint.(StdNetEndpoint)
if !ok {
@@ -186,8 +200,13 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
- _, err = conn.WriteToUDPAddrPort(buff, addrPort)
- return err
+ for _, buff := range buffs {
+ _, err = conn.WriteToUDPAddrPort(buff, addrPort)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index f8b187b..5a0b8c2 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -321,6 +321,11 @@ func (bind *WinRingBind) Close() error {
return nil
}
+func (bind *WinRingBind) BatchSize() int {
+ // TODO: implement batching in and out of the ring
+ return 1
+}
+
func (bind *WinRingBind) SetMark(mark uint32) error {
return nil
}
@@ -409,16 +414,22 @@ retry:
return n, &ep, nil
}
-func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
- return bind.v4.Receive(buf, &bind.isOpen)
+ n, ep, err := bind.v4.Receive(buffs[0], &bind.isOpen)
+ sizes[0] = n
+ eps[0] = ep
+ return 1, err
}
-func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
- return bind.v6.Receive(buf, &bind.isOpen)
+ n, ep, err := bind.v6.Receive(buffs[0], &bind.isOpen)
+ sizes[0] = n
+ eps[0] = ep
+ return 1, err
}
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
@@ -473,32 +484,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
-func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
+func (bind *WinRingBind) Send(buffs [][]byte, endpoint Endpoint) error {
nend, ok := endpoint.(*WinRingEndpoint)
if !ok {
return ErrWrongEndpointType
}
bind.mu.RLock()
defer bind.mu.RUnlock()
- switch nend.family {
- case windows.AF_INET:
- if bind.v4.blackhole {
- return nil
- }
- return bind.v4.Send(buf, nend, &bind.isOpen)
- case windows.AF_INET6:
- if bind.v6.blackhole {
- return nil
+ for _, buf := range buffs {
+ switch nend.family {
+ case windows.AF_INET:
+ if bind.v4.blackhole {
+ continue
+ }
+ if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
+ return err
+ }
+ case windows.AF_INET6:
+ if bind.v6.blackhole {
+ continue
+ }
+ if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
+ return err
+ }
}
- return bind.v6.Send(buf, nend, &bind.isOpen)
}
return nil
}
-func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
- sysconn, err := bind.ipv4.SyscallConn()
+func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ sysconn, err := s.ipv4.SyscallConn()
if err != nil {
return err
}
@@ -511,14 +528,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
if err != nil {
return err
}
- bind.blackhole4 = blackhole
+ s.blackhole4 = blackhole
return nil
}
-func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
- bind.mu.Lock()
- defer bind.mu.Unlock()
- sysconn, err := bind.ipv6.SyscallConn()
+func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ sysconn, err := s.ipv6.SyscallConn()
if err != nil {
return err
}
@@ -531,7 +548,7 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole
if err != nil {
return err
}
- bind.blackhole6 = blackhole
+ s.blackhole6 = blackhole
return nil
}
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index 9605a2a..b33c53d 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -89,32 +89,39 @@ func (c *ChannelBind) Close() error {
return nil
}
+func (c *ChannelBind) BatchSize() int { return 1 }
+
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
- return func(b []byte) (n int, ep conn.Endpoint, err error) {
+ return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
select {
case <-c.closeSignal:
- return 0, nil, net.ErrClosed
+ return 0, net.ErrClosed
case rx := <-ch:
- return copy(b, rx), c.target6, nil
+ copied := copy(buffs[0], rx)
+ sizes[0] = copied
+ eps[0] = c.target6
+ return 1, nil
}
}
}
-func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
- select {
- case <-c.closeSignal:
- return net.ErrClosed
- default:
- bc := make([]byte, len(b))
- copy(bc, b)
- if ep.(ChannelEndpoint) == c.target4 {
- *c.tx4 <- bc
- } else if ep.(ChannelEndpoint) == c.target6 {
- *c.tx6 <- bc
- } else {
- return os.ErrInvalid
+func (c *ChannelBind) Send(buffs [][]byte, ep conn.Endpoint) error {
+ for _, b := range buffs {
+ select {
+ case <-c.closeSignal:
+ return net.ErrClosed
+ default:
+ bc := make([]byte, len(b))
+ copy(bc, b)
+ if ep.(ChannelEndpoint) == c.target4 {
+ *c.tx4 <- bc
+ } else if ep.(ChannelEndpoint) == c.target6 {
+ *c.tx6 <- bc
+ } else {
+ return os.ErrInvalid
+ }
}
}
return nil
diff --git a/conn/conn.go b/conn/conn.go
index 497b92a..8c0a827 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -15,10 +15,17 @@ import (
"strings"
)
-// A ReceiveFunc receives a single inbound packet from the network.
-// It writes the data into b. n is the length of the packet.
-// ep is the remote endpoint.
-type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
+const (
+ DefaultBatchSize = 1 // maximum number of packets handled per read and write
+)
+
+// A ReceiveFunc receives at least one packet from the network and writes them
+// into packets. On a successful read it returns the number of elements of
+// sizes, packets, and endpoints that should be evaluated. Some elements of
+// sizes may be zero, and callers should ignore them. Callers must pass a sizes
+// and eps slice with a length greater than or equal to the length of packets.
+// These lengths must not exceed the length of the associated Bind.BatchSize().
+type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
@@ -38,11 +45,16 @@ type Bind interface {
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
- // Send writes a packet b to address ep.
- Send(b []byte, ep Endpoint) error
+ // Send writes one or more packets in buffs to address ep. The length of
+ // buffs must not exceed BatchSize().
+ Send(buffs [][]byte, ep Endpoint) error
// ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error)
+
+ // BatchSize is the number of buffers expected to be passed to
+ // the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
+ BatchSize() int
}
// BindSocketToInterface is implemented by Bind objects that support being
diff --git a/conn/conn_test.go b/conn/conn_test.go
new file mode 100644
index 0000000..7a6231d
--- /dev/null
+++ b/conn/conn_test.go
@@ -0,0 +1,24 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "testing"
+)
+
+func TestPrettyName(t *testing.T) {
+ var (
+ recvFunc ReceiveFunc = func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
+ )
+
+ const want = "TestPrettyName"
+
+ t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
+ if got := recvFunc.PrettyName(); got != want {
+ t.Errorf("PrettyName() = %v, want %v", got, want)
+ }
+ })
+}