aboutsummaryrefslogtreecommitdiff
path: root/device
diff options
context:
space:
mode:
Diffstat (limited to 'device')
-rw-r--r--device/channels.go30
-rw-r--r--device/device.go27
-rw-r--r--device/device_test.go69
-rw-r--r--device/peer.go26
-rw-r--r--device/pools.go32
-rw-r--r--device/pools_test.go48
-rw-r--r--device/receive.go320
-rw-r--r--device/send.go270
8 files changed, 562 insertions, 260 deletions
diff --git a/device/channels.go b/device/channels.go
index 1bfeeaf..039d8df 100644
--- a/device/channels.go
+++ b/device/channels.go
@@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
}
type autodrainingInboundQueue struct {
- c chan *QueueInboundElement
+ c chan *[]*QueueInboundElement
}
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
@@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
// some other means, such as sending a sentinel nil values.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{
- c: make(chan *QueueInboundElement, QueueInboundSize),
+ c: make(chan *[]*QueueInboundElement, QueueInboundSize),
}
runtime.SetFinalizer(q, device.flushInboundQueue)
return q
@@ -90,10 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for {
select {
- case elem := <-q.c:
- elem.Lock()
- device.PutMessageBuffer(elem.buffer)
- device.PutInboundElement(elem)
+ case elems := <-q.c:
+ for _, elem := range *elems {
+ elem.Lock()
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsSlice(elems)
default:
return
}
@@ -101,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
}
type autodrainingOutboundQueue struct {
- c chan *QueueOutboundElement
+ c chan *[]*QueueOutboundElement
}
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
@@ -111,7 +114,7 @@ type autodrainingOutboundQueue struct {
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{
- c: make(chan *QueueOutboundElement, QueueOutboundSize),
+ c: make(chan *[]*QueueOutboundElement, QueueOutboundSize),
}
runtime.SetFinalizer(q, device.flushOutboundQueue)
return q
@@ -120,10 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for {
select {
- case elem := <-q.c:
- elem.Lock()
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ case elems := <-q.c:
+ for _, elem := range *elems {
+ elem.Lock()
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsSlice(elems)
default:
return
}
diff --git a/device/device.go b/device/device.go
index 3368a93..091c8d4 100644
--- a/device/device.go
+++ b/device/device.go
@@ -68,9 +68,11 @@ type Device struct {
cookieChecker CookieChecker
pool struct {
- messageBuffers *WaitPool
- inboundElements *WaitPool
- outboundElements *WaitPool
+ outboundElementsSlice *WaitPool
+ inboundElementsSlice *WaitPool
+ messageBuffers *WaitPool
+ inboundElements *WaitPool
+ outboundElements *WaitPool
}
queue struct {
@@ -295,6 +297,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.indexTable.Init()
+
device.PopulatePools()
// create queues
@@ -322,6 +325,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
return device
}
+// BatchSize returns the BatchSize for the device as a whole which is the max of
+// the bind batch size and the tun batch size. The batch size reported by device
+// is the size used to construct memory pools, and is the allowed batch size for
+// the lifetime of the device.
+func (device *Device) BatchSize() int {
+ size := device.net.bind.BatchSize()
+ dSize := device.tun.device.BatchSize()
+ if size < dSize {
+ size = dSize
+ }
+ return size
+}
+
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
@@ -472,11 +488,13 @@ func (device *Device) BindUpdate() error {
var err error
var recvFns []conn.ReceiveFunc
netc := &device.net
+
recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil {
netc.port = 0
return err
}
+
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
@@ -507,8 +525,9 @@ func (device *Device) BindUpdate() error {
device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+ batchSize := netc.bind.BatchSize()
for _, fn := range recvFns {
- go device.RoutineReceiveIncoming(fn)
+ go device.RoutineReceiveIncoming(batchSize, fn)
}
device.log.Verbosef("UDP bind has been updated")
diff --git a/device/device_test.go b/device/device_test.go
index 975da64..73891bf 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -12,6 +12,7 @@ import (
"io"
"math/rand"
"net/netip"
+ "os"
"runtime"
"runtime/pprof"
"sync"
@@ -21,6 +22,7 @@ import (
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
+ "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/tuntest"
)
@@ -307,6 +309,17 @@ func TestConcurrencySafety(t *testing.T) {
}
})
+ // Perform bind updates and keepalive sends concurrently with tunnel use.
+ t.Run("bindUpdate and keepalive", func(t *testing.T) {
+ const iters = 10
+ for i := 0; i < iters; i++ {
+ for _, peer := range pair {
+ peer.dev.BindUpdate()
+ peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
+ }
+ }
+ })
+
close(done)
}
@@ -405,3 +418,59 @@ func goroutineLeakCheck(t *testing.T) {
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
})
}
+
+type fakeBindSized struct {
+ size int
+}
+
+func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ return nil, 0, nil
+}
+func (b *fakeBindSized) Close() error { return nil }
+func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
+func (b *fakeBindSized) Send(buffs [][]byte, ep conn.Endpoint) error { return nil }
+func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
+func (b *fakeBindSized) BatchSize() int { return b.size }
+
+type fakeTUNDeviceSized struct {
+ size int
+}
+
+func (t *fakeTUNDeviceSized) File() *os.File { return nil }
+func (t *fakeTUNDeviceSized) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) {
+ return 0, nil
+}
+func (t *fakeTUNDeviceSized) Write(buffs [][]byte, offset int) (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
+func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
+func (t *fakeTUNDeviceSized) Close() error { return nil }
+func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
+
+func TestBatchSize(t *testing.T) {
+ d := Device{}
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 1, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+}
diff --git a/device/peer.go b/device/peer.go
index 0e7b669..0ac4896 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -45,9 +45,9 @@ type Peer struct {
}
queue struct {
- staged chan *QueueOutboundElement // staged packets before a handshake is available
- outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
- inbound *autodrainingInboundQueue // sequential ordering of tun writing
+ staged chan *[]*QueueOutboundElement // staged packets before a handshake is available
+ outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
+ inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
cookieGenerator CookieGenerator
@@ -81,7 +81,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device)
- peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize)
+ peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize)
// map public key
_, ok := device.peers.keyMap[pk]
@@ -108,7 +108,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return peer, nil
}
-func (peer *Peer) SendBuffer(buffer []byte) error {
+func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
@@ -123,9 +123,13 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
return errors.New("no known endpoint for peer")
}
- err := peer.device.net.bind.Send(buffer, peer.endpoint)
+ err := peer.device.net.bind.Send(buffers, peer.endpoint)
if err == nil {
- peer.txBytes.Add(uint64(len(buffer)))
+ var totalLen uint64
+ for _, b := range buffers {
+ totalLen += uint64(len(b))
+ }
+ peer.txBytes.Add(totalLen)
}
return err
}
@@ -187,8 +191,12 @@ func (peer *Peer) Start() {
device.flushInboundQueue(peer.queue.inbound)
device.flushOutboundQueue(peer.queue.outbound)
- go peer.RoutineSequentialSender()
- go peer.RoutineSequentialReceiver()
+
+ // Use the device batch size, not the bind batch size, as the device size is
+ // the size of the batch pools.
+ batchSize := peer.device.BatchSize()
+ go peer.RoutineSequentialSender(batchSize)
+ go peer.RoutineSequentialReceiver(batchSize)
peer.isRunning.Store(true)
}
diff --git a/device/pools.go b/device/pools.go
index 239757f..02a5d6a 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -46,6 +46,14 @@ func (p *WaitPool) Put(x any) {
}
func (device *Device) PopulatePools() {
+ device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueOutboundElement, 0, device.BatchSize())
+ return &s
+ })
+ device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueInboundElement, 0, device.BatchSize())
+ return &s
+ })
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
})
@@ -57,6 +65,30 @@ func (device *Device) PopulatePools() {
})
}
+func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement {
+ return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement)
+}
+
+func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) {
+ for i := range *s {
+ (*s)[i] = nil
+ }
+ *s = (*s)[:0]
+ device.pool.outboundElementsSlice.Put(s)
+}
+
+func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement {
+ return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement)
+}
+
+func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) {
+ for i := range *s {
+ (*s)[i] = nil
+ }
+ *s = (*s)[:0]
+ device.pool.inboundElementsSlice.Put(s)
+}
+
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}
diff --git a/device/pools_test.go b/device/pools_test.go
index 1502a29..82d7493 100644
--- a/device/pools_test.go
+++ b/device/pools_test.go
@@ -89,3 +89,51 @@ func BenchmarkWaitPool(b *testing.B) {
}
wg.Wait()
}
+
+func BenchmarkWaitPoolEmpty(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := NewWaitPool(0, func() any { return make([]byte, 16) })
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkSyncPool(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := sync.Pool{New: func() any { return make([]byte, 16) }}
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
diff --git a/device/receive.go b/device/receive.go
index 03fcf00..aee7864 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -66,7 +66,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
-func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
+func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
recvName := recv.PrettyName()
defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
@@ -79,20 +79,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
// receive datagrams until conn is closed
- buffer := device.GetMessageBuffer()
-
var (
+ buffsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
+ buffs = make([][]byte, maxBatchSize)
err error
- size int
- endpoint conn.Endpoint
+ sizes = make([]int, maxBatchSize)
+ count int
+ endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
+ elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
)
- for {
- size, endpoint, err = recv(buffer[:])
+ for i := range buffsArrs {
+ buffsArrs[i] = device.GetMessageBuffer()
+ buffs[i] = buffsArrs[i][:]
+ }
+
+ defer func() {
+ for i := 0; i < maxBatchSize; i++ {
+ if buffsArrs[i] != nil {
+ device.PutMessageBuffer(buffsArrs[i])
+ }
+ }
+ }()
+ for {
+ count, err = recv(buffs, sizes, endpoints)
if err != nil {
- device.PutMessageBuffer(buffer)
if errors.Is(err, net.ErrClosed) {
return
}
@@ -103,101 +116,122 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
if deathSpiral < 10 {
deathSpiral++
time.Sleep(time.Second / 3)
- buffer = device.GetMessageBuffer()
continue
}
return
}
deathSpiral = 0
- if size < MinMessageSize {
- continue
- }
+ // handle each packet in the batch
+ for i, size := range sizes[:count] {
+ if size < MinMessageSize {
+ continue
+ }
- // check size of packet
+ // check size of packet
- packet := buffer[:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
+ packet := buffsArrs[i][:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
- var okay bool
+ switch msgType {
- switch msgType {
+ // check if transport
- // check if transport
+ case MessageTransportType:
- case MessageTransportType:
+ // check size
- // check size
+ if len(packet) < MessageTransportSize {
+ continue
+ }
- if len(packet) < MessageTransportSize {
- continue
- }
+ // lookup key pair
- // lookup key pair
+ receiver := binary.LittleEndian.Uint32(
+ packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+ )
+ value := device.indexTable.Lookup(receiver)
+ keypair := value.keypair
+ if keypair == nil {
+ continue
+ }
- receiver := binary.LittleEndian.Uint32(
- packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
- )
- value := device.indexTable.Lookup(receiver)
- keypair := value.keypair
- if keypair == nil {
- continue
- }
+ // check keypair expiry
- // check keypair expiry
+ if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ continue
+ }
- if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ // create work element
+ peer := value.peer
+ elem := device.GetInboundElement()
+ elem.packet = packet
+ elem.buffer = buffsArrs[i]
+ elem.keypair = keypair
+ elem.endpoint = endpoints[i]
+ elem.counter = 0
+ elem.Mutex = sync.Mutex{}
+ elem.Lock()
+
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetInboundElementsSlice()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ *elemsForPeer = append(*elemsForPeer, elem)
+ buffsArrs[i] = device.GetMessageBuffer()
+ buffs[i] = buffsArrs[i][:]
continue
- }
-
- // create work element
- peer := value.peer
- elem := device.GetInboundElement()
- elem.packet = packet
- elem.buffer = buffer
- elem.keypair = keypair
- elem.endpoint = endpoint
- elem.counter = 0
- elem.Mutex = sync.Mutex{}
- elem.Lock()
- // add to decryption queues
- if peer.isRunning.Load() {
- peer.queue.inbound.c <- elem
- device.queue.decryption.c <- elem
- buffer = device.GetMessageBuffer()
- } else {
- device.PutInboundElement(elem)
- }
- continue
+ // otherwise it is a fixed size & handshake related packet
- // otherwise it is a fixed size & handshake related packet
-
- case MessageInitiationType:
- okay = len(packet) == MessageInitiationSize
+ case MessageInitiationType:
+ if len(packet) != MessageInitiationSize {
+ continue
+ }
- case MessageResponseType:
- okay = len(packet) == MessageResponseSize
+ case MessageResponseType:
+ if len(packet) != MessageResponseSize {
+ continue
+ }
- case MessageCookieReplyType:
- okay = len(packet) == MessageCookieReplySize
+ case MessageCookieReplyType:
+ if len(packet) != MessageCookieReplySize {
+ continue
+ }
- default:
- device.log.Verbosef("Received message with unknown type")
- }
+ default:
+ device.log.Verbosef("Received message with unknown type")
+ continue
+ }
- if okay {
select {
case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType,
- buffer: buffer,
+ buffer: buffsArrs[i],
packet: packet,
- endpoint: endpoint,
+ endpoint: endpoints[i],
}:
- buffer = device.GetMessageBuffer()
+ buffsArrs[i] = device.GetMessageBuffer()
+ buffs[i] = buffsArrs[i][:]
default:
}
}
+ for peer, elems := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.queue.inbound.c <- elems
+ for _, elem := range *elems {
+ device.queue.decryption.c <- elem
+ }
+ } else {
+ for _, elem := range *elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsSlice(elems)
+ }
+ delete(elemsByPeer, peer)
+ }
}
}
@@ -393,7 +427,7 @@ func (device *Device) RoutineHandshake(id int) {
}
}
-func (peer *Peer) RoutineSequentialReceiver() {
+func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device := peer.device
defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
@@ -401,89 +435,91 @@ func (peer *Peer) RoutineSequentialReceiver() {
}()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
- for elem := range peer.queue.inbound.c {
- if elem == nil {
+ buffs := make([][]byte, 0, maxBatchSize)
+
+ for elems := range peer.queue.inbound.c {
+ if elems == nil {
return
}
- var err error
- elem.Lock()
- if elem.packet == nil {
- // decryption failed
- goto skip
- }
+ for _, elem := range *elems {
+ elem.Lock()
+ if elem.packet == nil {
+ // decryption failed
+ continue
+ }
- if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
- goto skip
- }
+ if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
+ continue
+ }
- peer.SetEndpointFromPacket(elem.endpoint)
- if peer.ReceivedWithKeypair(elem.keypair) {
- peer.timersHandshakeComplete()
- peer.SendStagedPackets()
- }
+ peer.SetEndpointFromPacket(elem.endpoint)
+ if peer.ReceivedWithKeypair(elem.keypair) {
+ peer.timersHandshakeComplete()
+ peer.SendStagedPackets()
+ }
+ peer.keepKeyFreshReceiving()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
+ peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
- peer.keepKeyFreshReceiving()
- peer.timersAnyAuthenticatedPacketTraversal()
- peer.timersAnyAuthenticatedPacketReceived()
- peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
+ if len(elem.packet) == 0 {
+ device.log.Verbosef("%v - Receiving keepalive packet", peer)
+ continue
+ }
+ peer.timersDataReceived()
- if len(elem.packet) == 0 {
- device.log.Verbosef("%v - Receiving keepalive packet", peer)
- goto skip
- }
- peer.timersDataReceived()
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
+ continue
+ }
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- goto skip
- }
- field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
- length := binary.BigEndian.Uint16(field)
- if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
- goto skip
- }
- elem.packet = elem.packet[:length]
- src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.allowedips.Lookup(src) != peer {
- device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
- goto skip
- }
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
+ continue
+ }
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
- goto skip
- }
- field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
- length := binary.BigEndian.Uint16(field)
- length += ipv6.HeaderLen
- if int(length) > len(elem.packet) {
- goto skip
- }
- elem.packet = elem.packet[:length]
- src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.allowedips.Lookup(src) != peer {
- device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
- goto skip
+ default:
+ device.log.Verbosef("Packet with invalid IP version from %v", peer)
+ continue
}
- default:
- device.log.Verbosef("Packet with invalid IP version from %v", peer)
- goto skip
+ buffs = append(buffs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
}
-
- _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
- if err != nil && !device.isClosed() {
- device.log.Errorf("Failed to write packet to TUN device: %v", err)
- }
- if len(peer.queue.inbound.c) == 0 {
- err = device.tun.device.Flush()
- if err != nil {
- peer.device.log.Errorf("Unable to flush packets: %v", err)
+ if len(buffs) > 0 {
+ _, err := device.tun.device.Write(buffs, MessageTransportOffsetContent)
+ if err != nil && !device.isClosed() {
+ device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
- skip:
- device.PutMessageBuffer(elem.buffer)
- device.PutInboundElement(elem)
+ for _, elem := range *elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ buffs = buffs[:0]
+ device.PutInboundElementsSlice(elems)
}
}
diff --git a/device/send.go b/device/send.go
index 854d172..b33b9f4 100644
--- a/device/send.go
+++ b/device/send.go
@@ -17,6 +17,7 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
+ "golang.zx2c4.com/wireguard/tun"
)
/* Outbound flow
@@ -77,12 +78,15 @@ func (elem *QueueOutboundElement) clearPointers() {
func (peer *Peer) SendKeepalive() {
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement()
+ elems := peer.device.GetOutboundElementsSlice()
+ *elems = append(*elems, elem)
select {
- case peer.queue.staged <- elem:
+ case peer.queue.staged <- elems:
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
+ peer.device.PutOutboundElementsSlice(elems)
}
}
peer.SendStagedPackets()
@@ -125,7 +129,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
@@ -163,7 +167,8 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ // TODO: allocation could be avoided
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
@@ -183,7 +188,8 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
var buff [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, reply)
- device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
+ // TODO: allocation could be avoided
+ device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
return nil
}
@@ -198,11 +204,6 @@ func (peer *Peer) keepKeyFreshSending() {
}
}
-/* Reads packets from the TUN and inserts
- * into staged queue for peer
- *
- * Obs. Single instance per TUN device
- */
func (device *Device) RoutineReadFromTUN() {
defer func() {
device.log.Verbosef("Routine: TUN reader - stopped")
@@ -212,81 +213,123 @@ func (device *Device) RoutineReadFromTUN() {
device.log.Verbosef("Routine: TUN reader - started")
- var elem *QueueOutboundElement
+ var (
+ batchSize = device.BatchSize()
+ readErr error
+ elems = make([]*QueueOutboundElement, batchSize)
+ buffs = make([][]byte, batchSize)
+ elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize)
+ count = 0
+ sizes = make([]int, batchSize)
+ offset = MessageTransportHeaderSize
+ )
+
+ for i := range elems {
+ elems[i] = device.NewOutboundElement()
+ buffs[i] = elems[i].buffer[:]
+ }
- for {
- if elem != nil {
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ defer func() {
+ for _, elem := range elems {
+ if elem != nil {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
}
- elem = device.NewOutboundElement()
-
- // read packet
+ }()
- offset := MessageTransportHeaderSize
- size, err := device.tun.device.Read(elem.buffer[:], offset)
- if err != nil {
- if !device.isClosed() {
- if !errors.Is(err, os.ErrClosed) {
- device.log.Errorf("Failed to read packet from TUN device: %v", err)
- }
- go device.Close()
+ for {
+ // read packets
+ count, readErr = device.tun.device.Read(buffs, sizes, offset)
+ for i := 0; i < count; i++ {
+ if sizes[i] < 1 {
+ continue
}
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- return
- }
- if size == 0 || size > MaxContentSize {
- continue
- }
+ elem := elems[i]
+ elem.packet = buffs[i][offset : offset+sizes[i]]
- elem.packet = elem.buffer[offset : offset+size]
+ // lookup peer
+ var peer *Peer
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ peer = device.allowedips.Lookup(dst)
- // lookup peer
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ peer = device.allowedips.Lookup(dst)
- var peer *Peer
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- continue
+ default:
+ device.log.Verbosef("Received packet with unknown IP version")
}
- dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
- peer = device.allowedips.Lookup(dst)
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
+ if peer == nil {
continue
}
- dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
- peer = device.allowedips.Lookup(dst)
-
- default:
- device.log.Verbosef("Received packet with unknown IP version")
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetOutboundElementsSlice()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ *elemsForPeer = append(*elemsForPeer, elem)
+ elems[i] = device.NewOutboundElement()
+ buffs[i] = elems[i].buffer[:]
}
- if peer == nil {
- continue
+ for peer, elemsForPeer := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.StagePackets(elemsForPeer)
+ peer.SendStagedPackets()
+ } else {
+ for _, elem := range *elemsForPeer {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsSlice(elemsForPeer)
+ }
+ delete(elemsByPeer, peer)
}
- if peer.isRunning.Load() {
- peer.StagePacket(elem)
- elem = nil
- peer.SendStagedPackets()
+
+ if readErr != nil {
+ if errors.Is(readErr, tun.ErrTooManySegments) {
+ // TODO: record stat for this
+ // This will happen if MSS is surprisingly small (< 576)
+ // coincident with reasonably high throughput.
+ device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
+ continue
+ }
+ if !device.isClosed() {
+ if !errors.Is(readErr, os.ErrClosed) {
+ device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
+ }
+ go device.Close()
+ }
+ return
}
}
}
-func (peer *Peer) StagePacket(elem *QueueOutboundElement) {
+func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
for {
select {
- case peer.queue.staged <- elem:
+ case peer.queue.staged <- elems:
return
default:
}
select {
case tooOld := <-peer.queue.staged:
- peer.device.PutMessageBuffer(tooOld.buffer)
- peer.device.PutOutboundElement(tooOld)
+ for _, elem := range *tooOld {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsSlice(tooOld)
default:
}
}
@@ -305,26 +348,55 @@ top:
}
for {
+ var elemsOOO *[]*QueueOutboundElement
select {
- case elem := <-peer.queue.staged:
- elem.peer = peer
- elem.nonce = keypair.sendNonce.Add(1) - 1
- if elem.nonce >= RejectAfterMessages {
- keypair.sendNonce.Store(RejectAfterMessages)
- peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
- goto top
+ case elems := <-peer.queue.staged:
+ i := 0
+ for _, elem := range *elems {
+ elem.peer = peer
+ elem.nonce = keypair.sendNonce.Add(1) - 1
+ if elem.nonce >= RejectAfterMessages {
+ keypair.sendNonce.Store(RejectAfterMessages)
+ if elemsOOO == nil {
+ elemsOOO = peer.device.GetOutboundElementsSlice()
+ }
+ *elemsOOO = append(*elemsOOO, elem)
+ continue
+ } else {
+ (*elems)[i] = elem
+ i++
+ }
+
+ elem.keypair = keypair
+ elem.Lock()
}
+ *elems = (*elems)[:i]
- elem.keypair = keypair
- elem.Lock()
+ if elemsOOO != nil {
+ peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans
+ }
+
+ if len(*elems) == 0 {
+ peer.device.PutOutboundElementsSlice(elems)
+ goto top
+ }
// add to parallel and sequential queue
if peer.isRunning.Load() {
- peer.queue.outbound.c <- elem
- peer.device.queue.encryption.c <- elem
+ peer.queue.outbound.c <- elems
+ for _, elem := range *elems {
+ peer.device.queue.encryption.c <- elem
+ }
} else {
- peer.device.PutMessageBuffer(elem.buffer)
- peer.device.PutOutboundElement(elem)
+ for _, elem := range *elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsSlice(elems)
+ }
+
+ if elemsOOO != nil {
+ goto top
}
default:
return
@@ -335,9 +407,12 @@ top:
func (peer *Peer) FlushStagedPackets() {
for {
select {
- case elem := <-peer.queue.staged:
- peer.device.PutMessageBuffer(elem.buffer)
- peer.device.PutOutboundElement(elem)
+ case elems := <-peer.queue.staged:
+ for _, elem := range *elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsSlice(elems)
default:
return
}
@@ -400,12 +475,7 @@ func (device *Device) RoutineEncryption(id int) {
}
}
-/* Sequentially reads packets from queue and sends to endpoint
- *
- * Obs. Single instance per peer.
- * The routine terminates then the outbound queue is closed.
- */
-func (peer *Peer) RoutineSequentialSender() {
+func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device := peer.device
defer func() {
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
@@ -413,36 +483,50 @@ func (peer *Peer) RoutineSequentialSender() {
}()
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
- for elem := range peer.queue.outbound.c {
- if elem == nil {
+ buffs := make([][]byte, 0, maxBatchSize)
+
+ for elems := range peer.queue.outbound.c {
+ buffs = buffs[:0]
+ if elems == nil {
return
}
- elem.Lock()
if !peer.isRunning.Load() {
// peer has been stopped; return re-usable elems to the shared pool.
// This is an optimization only. It is possible for the peer to be stopped
// immediately after this check, in which case, elem will get processed.
- // The timers and SendBuffer code are resilient to a few stragglers.
+ // The timers and SendBuffers code are resilient to a few stragglers.
// TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary.
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ for _, elem := range *elems {
+ elem.Lock()
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
continue
}
+ dataSent := false
+ for _, elem := range *elems {
+ elem.Lock()
+ if len(elem.packet) != MessageKeepaliveSize {
+ dataSent = true
+ }
+ buffs = append(buffs, elem.packet)
+ }
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- // send message and return buffer to pool
-
- err := peer.SendBuffer(elem.packet)
- if len(elem.packet) != MessageKeepaliveSize {
+ err := peer.SendBuffers(buffs)
+ if dataSent {
peer.timersDataSent()
}
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ for _, elem := range *elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsSlice(elems)
if err != nil {
- device.log.Errorf("%v - Failed to send data packet: %v", peer, err)
+ device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
continue
}