From eb75ff430d1f78e129bbfe49d612f241ca418df4 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Mon, 26 Jun 2017 22:07:29 +0200 Subject: Begin implementation of outbound work queue --- src/device.go | 3 ++ src/index.go | 75 ++++++++++++++++++------------ src/keypair.go | 16 ++++++- src/noise_protocol.go | 43 ++++++++++++++--- src/peer.go | 4 +- src/send.go | 124 +++++++++++++++++++++++++++++++------------------- 6 files changed, 181 insertions(+), 84 deletions(-) (limited to 'src') diff --git a/src/device.go b/src/device.go index ce10a63..4b8cda0 100644 --- a/src/device.go +++ b/src/device.go @@ -2,11 +2,14 @@ package main import ( "log" + "net" "sync" ) type Device struct { mtu int + source *net.UDPAddr // UDP source address + conn *net.UDPConn // UDP "connection" mutex sync.RWMutex peers map[NoisePublicKey]*Peer indices IndexTable diff --git a/src/index.go b/src/index.go index 81f71e9..9178510 100644 --- a/src/index.go +++ b/src/index.go @@ -11,10 +11,15 @@ import ( * */ +type IndexTableEntry struct { + peer *Peer + handshake *Handshake + keyPair *KeyPair +} + type IndexTable struct { - mutex sync.RWMutex - keypairs map[uint32]*KeyPair - handshakes map[uint32]*Peer + mutex sync.RWMutex + table map[uint32]IndexTableEntry } func randUint32() (uint32, error) { @@ -32,52 +37,66 @@ func randUint32() (uint32, error) { func (table *IndexTable) Init() { table.mutex.Lock() - defer table.mutex.Unlock() - table.keypairs = make(map[uint32]*KeyPair) - table.handshakes = make(map[uint32]*Peer) + table.table = make(map[uint32]IndexTableEntry) + table.mutex.Unlock() } -func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { +func (table *IndexTable) ClearIndex(index uint32) { + if index == 0 { + return + } + table.mutex.Lock() + delete(table.table, index) + table.mutex.Unlock() +} + +func (table *IndexTable) Insert(key uint32, value IndexTableEntry) { table.mutex.Lock() - defer table.mutex.Unlock() + table.table[key] = value + table.mutex.Unlock() +} + +func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) { for { // generate random index - id, err := randUint32() + index, err := randUint32() if err != nil { - return id, err + return index, err } - if id == 0 { + if index == 0 { continue } // check if index used - _, ok := table.keypairs[id] - if ok { - continue - } - _, ok = table.handshakes[id] + table.mutex.RLock() + _, ok := table.table[index] if ok { continue } + table.mutex.RUnlock() - // clean old index + // replace index - delete(table.handshakes, peer.handshake.localIndex) - table.handshakes[id] = peer - return id, nil + table.mutex.Lock() + _, found := table.table[index] + if found { + table.mutex.Unlock() + continue + } + table.table[index] = IndexTableEntry{ + peer: peer, + handshake: &peer.handshake, + keyPair: nil, + } + table.mutex.Unlock() + return index, nil } } -func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.keypairs[id] -} - -func (table *IndexTable) LookupHandshake(id uint32) *Peer { +func (table *IndexTable) Lookup(id uint32) IndexTableEntry { table.mutex.RLock() defer table.mutex.RUnlock() - return table.handshakes[id] + return table.table[id] } diff --git a/src/keypair.go b/src/keypair.go index e7961a8..53e123f 100644 --- a/src/keypair.go +++ b/src/keypair.go @@ -16,6 +16,18 @@ type KeyPairs struct { mutex sync.RWMutex current *KeyPair previous *KeyPair - next *KeyPair - newKeyPair chan bool + next *KeyPair // not yet "confirmed by transport" + newKeyPair chan bool // signals when "current" has been updated +} + +func (kp *KeyPairs) Init() { + kp.mutex.Lock() + kp.newKeyPair = make(chan bool, 5) + kp.mutex.Unlock() +} + +func (kp *KeyPairs) Current() *KeyPair { + kp.mutex.RLock() + defer kp.mutex.RUnlock() + return kp.current } diff --git a/src/noise_protocol.go b/src/noise_protocol.go index a16908a..bf1db9b 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -120,13 +120,15 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e return nil, err } + device.indices.ClearIndex(handshake.localIndex) + handshake.localIndex, err = device.indices.NewIndex(peer) + // assign index var msg MessageInitiation msg.Type = MessageInitiationType msg.Ephemeral = handshake.localEphemeral.publicKey() - handshake.localIndex, err = device.indices.NewIndex(peer) if err != nil { return nil, err @@ -249,6 +251,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error // assign index var err error + device.indices.ClearIndex(handshake.localIndex) handshake.localIndex, err = device.indices.NewIndex(peer) if err != nil { return nil, err @@ -299,11 +302,12 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { // lookup handshake by reciever - peer := device.indices.LookupHandshake(msg.Reciever) - if peer == nil { + lookup := device.indices.Lookup(msg.Reciever) + handshake := lookup.handshake + if handshake == nil { return nil } - handshake := &peer.handshake + handshake.mutex.Lock() defer handshake.mutex.Unlock() if handshake.state != HandshakeInitiationCreated { @@ -345,7 +349,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { handshake.remoteIndex = msg.Sender handshake.state = HandshakeResponseConsumed - return peer + return lookup.peer } func (peer *Peer) NewKeyPair() *KeyPair { @@ -355,13 +359,16 @@ func (peer *Peer) NewKeyPair() *KeyPair { // derive keys + var isInitiator bool var sendKey [chacha20poly1305.KeySize]byte var recvKey [chacha20poly1305.KeySize]byte if handshake.state == HandshakeResponseConsumed { sendKey, recvKey = KDF2(handshake.chainKey[:], nil) + isInitiator = true } else if handshake.state == HandshakeResponseCreated { recvKey, sendKey = KDF2(handshake.chainKey[:], nil) + isInitiator = false } else { return nil } @@ -369,16 +376,40 @@ func (peer *Peer) NewKeyPair() *KeyPair { // create AEAD instances var keyPair KeyPair + keyPair.send, _ = chacha20poly1305.New(sendKey[:]) keyPair.recv, _ = chacha20poly1305.New(recvKey[:]) keyPair.sendNonce = 0 keyPair.recvNonce = 0 + // remap index + + peer.device.indices.Insert(handshake.localIndex, IndexTableEntry{ + peer: peer, + keyPair: &keyPair, + handshake: nil, + }) + handshake.localIndex = 0 + + // rotate key pairs + + func() { + kp := &peer.keyPairs + kp.mutex.Lock() + defer kp.mutex.Unlock() + if isInitiator { + kp.previous = peer.keyPairs.current + kp.current = &keyPair + kp.newKeyPair <- true + } else { + kp.next = &keyPair + } + }() + // zero handshake handshake.chainKey = [blake2s.Size]byte{} handshake.localEphemeral = NoisePrivateKey{} peer.handshake.state = HandshakeZeroed - return &keyPair } diff --git a/src/peer.go b/src/peer.go index 42b9e8d..6a879cb 100644 --- a/src/peer.go +++ b/src/peer.go @@ -14,8 +14,7 @@ const ( type Peer struct { mutex sync.RWMutex - endpointIP net.IP // - endpointPort uint16 // + endpoint *net.UDPAddr persistentKeepaliveInterval time.Duration // 0 = disabled keyPairs KeyPairs handshake Handshake @@ -35,6 +34,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { peer.mutex.Lock() peer.device = device + peer.keyPairs.Init() peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize) // map public key diff --git a/src/send.go b/src/send.go index 9790320..da5905d 100644 --- a/src/send.go +++ b/src/send.go @@ -1,9 +1,11 @@ package main import ( + "encoding/binary" + "golang.org/x/crypto/chacha20poly1305" "net" "sync" - "sync/atomic" + "time" ) /* Handles outbound flow @@ -70,85 +72,115 @@ func (device *Device) SendPacket(packet []byte) { * * TODO: avoid dynamic allocation of work queue elements */ -func (peer *Peer) ConsumeOutboundPackets() { +func (peer *Peer) RoutineOutboundNonceWorker() { + var packet []byte + var keyPair *KeyPair + var flushTimer time.Timer + for { - // wait for key pair - keyPair := func() *KeyPair { - peer.keyPairs.mutex.RLock() - defer peer.keyPairs.mutex.RUnlock() - return peer.keyPairs.current - }() - if keyPair == nil { - if len(peer.queueOutboundRouting) > 0 { - // TODO: start handshake - <-peer.keyPairs.newKeyPair - } - continue + + // wait for packet + + if packet == nil { + packet = <-peer.queueOutboundRouting } - // assign packets key pair - for { + // wait for key pair + + for keyPair == nil { + flushTimer.Reset(time.Second * 10) + // TODO: Handshake or NOP select { case <-peer.keyPairs.newKeyPair: - default: - case <-peer.keyPairs.newKeyPair: - case packet := <-peer.queueOutboundRouting: + keyPair = peer.keyPairs.Current() + continue + case <-flushTimer.C: + size := len(peer.queueOutboundRouting) + for i := 0; i < size; i += 1 { + <-peer.queueOutboundRouting + } + packet = nil + } + break + } + + // process current packet + + if packet != nil { - // create new work element + // create work element - work := new(OutboundWorkQueueElement) - work.wg.Add(1) - work.keyPair = keyPair - work.packet = packet - work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 + work := new(OutboundWorkQueueElement) + work.wg.Add(1) + work.keyPair = keyPair + work.packet = packet + work.nonce = keyPair.sendNonce - peer.queueOutbound <- work + packet = nil + peer.queueOutbound <- work + keyPair.sendNonce += 1 - // drop packets until there is room + // drop packets until there is space + func() { for { select { case peer.device.queueWorkOutbound <- work: - break + return default: drop := <-peer.device.queueWorkOutbound drop.packet = nil drop.wg.Done() } } - } + }() } } } +/* Go routine + * + * sequentially reads packets from queue and sends to endpoint + * + */ func (peer *Peer) RoutineSequential() { for work := range peer.queueOutbound { work.wg.Wait() + + // check if dropped ("ghost packet") + if work.packet == nil { continue } + + // + } } -func (device *Device) EncryptionWorker() { - for { - work := <-device.queueWorkOutbound - - func() { - defer work.wg.Done() +func (device *Device) RoutineEncryptionWorker() { + var nonce [chacha20poly1305.NonceSize]byte + for work := range device.queueWorkOutbound { + // pad packet - // pad packet - padding := device.mtu - len(work.packet) - if padding < 0 { - work.packet = nil - return - } - for n := 0; n < padding; n += 1 { - work.packet = append(work.packet, 0) // TODO: gotta be a faster way - } + padding := device.mtu - len(work.packet) + if padding < 0 { + work.packet = nil + work.wg.Done() + } + for n := 0; n < padding; n += 1 { + work.packet = append(work.packet, 0) + } - // + // encrypt - }() + binary.LittleEndian.PutUint64(nonce[4:], work.nonce) + work.packet = work.keyPair.send.Seal( + work.packet[:0], + nonce[:], + work.packet, + nil, + ) + work.wg.Done() } } -- cgit v1.2.3