aboutsummaryrefslogtreecommitdiff
path: root/device/receive.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/receive.go')
-rw-r--r--device/receive.go320
1 files changed, 178 insertions, 142 deletions
diff --git a/device/receive.go b/device/receive.go
index 03fcf00..aee7864 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -66,7 +66,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
-func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
+func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
recvName := recv.PrettyName()
defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
@@ -79,20 +79,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
// receive datagrams until conn is closed
- buffer := device.GetMessageBuffer()
-
var (
+ buffsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
+ buffs = make([][]byte, maxBatchSize)
err error
- size int
- endpoint conn.Endpoint
+ sizes = make([]int, maxBatchSize)
+ count int
+ endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
+ elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
)
- for {
- size, endpoint, err = recv(buffer[:])
+ for i := range buffsArrs {
+ buffsArrs[i] = device.GetMessageBuffer()
+ buffs[i] = buffsArrs[i][:]
+ }
+
+ defer func() {
+ for i := 0; i < maxBatchSize; i++ {
+ if buffsArrs[i] != nil {
+ device.PutMessageBuffer(buffsArrs[i])
+ }
+ }
+ }()
+ for {
+ count, err = recv(buffs, sizes, endpoints)
if err != nil {
- device.PutMessageBuffer(buffer)
if errors.Is(err, net.ErrClosed) {
return
}
@@ -103,101 +116,122 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
if deathSpiral < 10 {
deathSpiral++
time.Sleep(time.Second / 3)
- buffer = device.GetMessageBuffer()
continue
}
return
}
deathSpiral = 0
- if size < MinMessageSize {
- continue
- }
+ // handle each packet in the batch
+ for i, size := range sizes[:count] {
+ if size < MinMessageSize {
+ continue
+ }
- // check size of packet
+ // check size of packet
- packet := buffer[:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
+ packet := buffsArrs[i][:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
- var okay bool
+ switch msgType {
- switch msgType {
+ // check if transport
- // check if transport
+ case MessageTransportType:
- case MessageTransportType:
+ // check size
- // check size
+ if len(packet) < MessageTransportSize {
+ continue
+ }
- if len(packet) < MessageTransportSize {
- continue
- }
+ // lookup key pair
- // lookup key pair
+ receiver := binary.LittleEndian.Uint32(
+ packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+ )
+ value := device.indexTable.Lookup(receiver)
+ keypair := value.keypair
+ if keypair == nil {
+ continue
+ }
- receiver := binary.LittleEndian.Uint32(
- packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
- )
- value := device.indexTable.Lookup(receiver)
- keypair := value.keypair
- if keypair == nil {
- continue
- }
+ // check keypair expiry
- // check keypair expiry
+ if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ continue
+ }
- if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ // create work element
+ peer := value.peer
+ elem := device.GetInboundElement()
+ elem.packet = packet
+ elem.buffer = buffsArrs[i]
+ elem.keypair = keypair
+ elem.endpoint = endpoints[i]
+ elem.counter = 0
+ elem.Mutex = sync.Mutex{}
+ elem.Lock()
+
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetInboundElementsSlice()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ *elemsForPeer = append(*elemsForPeer, elem)
+ buffsArrs[i] = device.GetMessageBuffer()
+ buffs[i] = buffsArrs[i][:]
continue
- }
-
- // create work element
- peer := value.peer
- elem := device.GetInboundElement()
- elem.packet = packet
- elem.buffer = buffer
- elem.keypair = keypair
- elem.endpoint = endpoint
- elem.counter = 0
- elem.Mutex = sync.Mutex{}
- elem.Lock()
- // add to decryption queues
- if peer.isRunning.Load() {
- peer.queue.inbound.c <- elem
- device.queue.decryption.c <- elem
- buffer = device.GetMessageBuffer()
- } else {
- device.PutInboundElement(elem)
- }
- continue
+ // otherwise it is a fixed size & handshake related packet
- // otherwise it is a fixed size & handshake related packet
-
- case MessageInitiationType:
- okay = len(packet) == MessageInitiationSize
+ case MessageInitiationType:
+ if len(packet) != MessageInitiationSize {
+ continue
+ }
- case MessageResponseType:
- okay = len(packet) == MessageResponseSize
+ case MessageResponseType:
+ if len(packet) != MessageResponseSize {
+ continue
+ }
- case MessageCookieReplyType:
- okay = len(packet) == MessageCookieReplySize
+ case MessageCookieReplyType:
+ if len(packet) != MessageCookieReplySize {
+ continue
+ }
- default:
- device.log.Verbosef("Received message with unknown type")
- }
+ default:
+ device.log.Verbosef("Received message with unknown type")
+ continue
+ }
- if okay {
select {
case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType,
- buffer: buffer,
+ buffer: buffsArrs[i],
packet: packet,
- endpoint: endpoint,
+ endpoint: endpoints[i],
}:
- buffer = device.GetMessageBuffer()
+ buffsArrs[i] = device.GetMessageBuffer()
+ buffs[i] = buffsArrs[i][:]
default:
}
}
+ for peer, elems := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.queue.inbound.c <- elems
+ for _, elem := range *elems {
+ device.queue.decryption.c <- elem
+ }
+ } else {
+ for _, elem := range *elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsSlice(elems)
+ }
+ delete(elemsByPeer, peer)
+ }
}
}
@@ -393,7 +427,7 @@ func (device *Device) RoutineHandshake(id int) {
}
}
-func (peer *Peer) RoutineSequentialReceiver() {
+func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device := peer.device
defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
@@ -401,89 +435,91 @@ func (peer *Peer) RoutineSequentialReceiver() {
}()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
- for elem := range peer.queue.inbound.c {
- if elem == nil {
+ buffs := make([][]byte, 0, maxBatchSize)
+
+ for elems := range peer.queue.inbound.c {
+ if elems == nil {
return
}
- var err error
- elem.Lock()
- if elem.packet == nil {
- // decryption failed
- goto skip
- }
+ for _, elem := range *elems {
+ elem.Lock()
+ if elem.packet == nil {
+ // decryption failed
+ continue
+ }
- if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
- goto skip
- }
+ if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
+ continue
+ }
- peer.SetEndpointFromPacket(elem.endpoint)
- if peer.ReceivedWithKeypair(elem.keypair) {
- peer.timersHandshakeComplete()
- peer.SendStagedPackets()
- }
+ peer.SetEndpointFromPacket(elem.endpoint)
+ if peer.ReceivedWithKeypair(elem.keypair) {
+ peer.timersHandshakeComplete()
+ peer.SendStagedPackets()
+ }
+ peer.keepKeyFreshReceiving()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
+ peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
- peer.keepKeyFreshReceiving()
- peer.timersAnyAuthenticatedPacketTraversal()
- peer.timersAnyAuthenticatedPacketReceived()
- peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
+ if len(elem.packet) == 0 {
+ device.log.Verbosef("%v - Receiving keepalive packet", peer)
+ continue
+ }
+ peer.timersDataReceived()
- if len(elem.packet) == 0 {
- device.log.Verbosef("%v - Receiving keepalive packet", peer)
- goto skip
- }
- peer.timersDataReceived()
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
+ continue
+ }
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- goto skip
- }
- field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
- length := binary.BigEndian.Uint16(field)
- if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
- goto skip
- }
- elem.packet = elem.packet[:length]
- src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.allowedips.Lookup(src) != peer {
- device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
- goto skip
- }
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
+ continue
+ }
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
- goto skip
- }
- field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
- length := binary.BigEndian.Uint16(field)
- length += ipv6.HeaderLen
- if int(length) > len(elem.packet) {
- goto skip
- }
- elem.packet = elem.packet[:length]
- src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.allowedips.Lookup(src) != peer {
- device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
- goto skip
+ default:
+ device.log.Verbosef("Packet with invalid IP version from %v", peer)
+ continue
}
- default:
- device.log.Verbosef("Packet with invalid IP version from %v", peer)
- goto skip
+ buffs = append(buffs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
}
-
- _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
- if err != nil && !device.isClosed() {
- device.log.Errorf("Failed to write packet to TUN device: %v", err)
- }
- if len(peer.queue.inbound.c) == 0 {
- err = device.tun.device.Flush()
- if err != nil {
- peer.device.log.Errorf("Unable to flush packets: %v", err)
+ if len(buffs) > 0 {
+ _, err := device.tun.device.Write(buffs, MessageTransportOffsetContent)
+ if err != nil && !device.isClosed() {
+ device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
- skip:
- device.PutMessageBuffer(elem.buffer)
- device.PutInboundElement(elem)
+ for _, elem := range *elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ buffs = buffs[:0]
+ device.PutInboundElementsSlice(elems)
}
}