From 9d806d3853c926df75e83966d2c4f832708a1b08 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Mon, 26 Jun 2017 13:14:02 +0200 Subject: Begin work on outbound packet flow --- src/cookie.go | 39 +++++++++++++ src/device.go | 20 ++++--- src/keypair.go | 13 ++++- src/main.go | 8 ++- src/noise_protocol.go | 106 ++++++++++++++++++---------------- src/noise_test.go | 4 +- src/peer.go | 43 +++++++++++--- src/routing.go | 23 -------- src/send.go | 154 ++++++++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 315 insertions(+), 95 deletions(-) create mode 100644 src/cookie.go create mode 100644 src/send.go diff --git a/src/cookie.go b/src/cookie.go new file mode 100644 index 0000000..a6987a2 --- /dev/null +++ b/src/cookie.go @@ -0,0 +1,39 @@ +package main + +import ( + "errors" + "golang.org/x/crypto/blake2s" +) + +func CalculateCookie(peer *Peer, msg []byte) { + size := len(msg) + + if size < blake2s.Size128*2 { + panic(errors.New("bug: message too short")) + } + + startMac1 := size - (blake2s.Size128 * 2) + startMac2 := size - blake2s.Size128 + + mac1 := msg[startMac1 : startMac1+blake2s.Size128] + mac2 := msg[startMac2 : startMac2+blake2s.Size128] + + peer.mutex.RLock() + defer peer.mutex.RUnlock() + + // set mac1 + + func() { + mac, _ := blake2s.New128(peer.macKey[:]) + mac.Write(msg[:startMac1]) + mac.Sum(mac1[:0]) + }() + + // set mac2 + + if peer.cookie != nil { + mac, _ := blake2s.New128(peer.cookie) + mac.Write(msg[:startMac2]) + mac.Sum(mac2[:0]) + } +} diff --git a/src/device.go b/src/device.go index 9969034..ce10a63 100644 --- a/src/device.go +++ b/src/device.go @@ -1,18 +1,22 @@ package main import ( + "log" "sync" ) type Device struct { - mutex sync.RWMutex - peers map[NoisePublicKey]*Peer - indices IndexTable - privateKey NoisePrivateKey - publicKey NoisePublicKey - fwMark uint32 - listenPort uint16 - routingTable RoutingTable + mtu int + mutex sync.RWMutex + peers map[NoisePublicKey]*Peer + indices IndexTable + privateKey NoisePrivateKey + publicKey NoisePublicKey + fwMark uint32 + listenPort uint16 + routingTable RoutingTable + logger log.Logger + queueWorkOutbound chan *OutboundWorkQueueElement } func (device *Device) SetPrivateKey(sk NoisePrivateKey) { diff --git a/src/keypair.go b/src/keypair.go index e434c74..e7961a8 100644 --- a/src/keypair.go +++ b/src/keypair.go @@ -2,11 +2,20 @@ package main import ( "crypto/cipher" + "sync" ) type KeyPair struct { recv cipher.AEAD - recvNonce NoiseNonce + recvNonce uint64 send cipher.AEAD - sendNonce NoiseNonce + sendNonce uint64 +} + +type KeyPairs struct { + mutex sync.RWMutex + current *KeyPair + previous *KeyPair + next *KeyPair + newKeyPair chan bool } diff --git a/src/main.go b/src/main.go index af336f0..b6f6deb 100644 --- a/src/main.go +++ b/src/main.go @@ -1,6 +1,8 @@ package main -import "fmt" +import ( + "fmt" +) func main() { fd, err := CreateTUN("test0") @@ -8,9 +10,9 @@ func main() { queue := make(chan []byte, 1000) - var device Device + // var device Device - go OutgoingRoutingWorker(&device, queue) + // go OutgoingRoutingWorker(&device, queue) for { tmp := make([]byte, 1<<16) diff --git a/src/noise_protocol.go b/src/noise_protocol.go index 7f26cf1..a16908a 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -9,9 +9,9 @@ import ( ) const ( - HandshakeReset = iota - HandshakeInitialCreated - HandshakeInitialConsumed + HandshakeZeroed = iota + HandshakeInitiationCreated + HandshakeInitiationConsumed HandshakeResponseCreated HandshakeResponseConsumed ) @@ -24,13 +24,19 @@ const ( ) const ( - MessageInitalType = 1 + MessageInitiationType = 1 MessageResponseType = 2 MessageCookieResponseType = 3 MessageTransportType = 4 ) -type MessageInital struct { +/* Type is an 8-bit field, followed by 3 nul bytes, + * by marshalling the messages in little-endian byteorder + * we can treat these as a 32-bit int + * + */ + +type MessageInitiation struct { Type uint32 Sender uint32 Ephemeral NoisePublicKey @@ -73,9 +79,9 @@ type Handshake struct { } var ( - ZeroNonce [chacha20poly1305.NonceSize]byte InitalChainKey [blake2s.Size]byte InitalHash [blake2s.Size]byte + ZeroNonce [chacha20poly1305.NonceSize]byte ) func init() { @@ -83,23 +89,23 @@ func init() { InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...)) } -func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { +func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { return KDF1(c[:], data) } -func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte { +func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte { return blake2s.Sum256(append(h[:], data...)) } -func (h *Handshake) addToHash(data []byte) { - h.hash = addToHash(h.hash, data) +func (h *Handshake) mixHash(data []byte) { + h.hash = mixHash(h.hash, data) } -func (h *Handshake) addToChainKey(data []byte) { - h.chainKey = addToChainKey(h.chainKey, data) +func (h *Handshake) mixKey(data []byte) { + h.chainKey = mixKey(h.chainKey, data) } -func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { +func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { handshake := &peer.handshake handshake.mutex.Lock() defer handshake.mutex.Unlock() @@ -108,7 +114,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { var err error handshake.chainKey = InitalChainKey - handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:]) + handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:]) handshake.localEphemeral, err = newPrivateKey() if err != nil { return nil, err @@ -116,9 +122,9 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { // assign index - var msg MessageInital + var msg MessageInitiation - msg.Type = MessageInitalType + msg.Type = MessageInitiationType msg.Ephemeral = handshake.localEphemeral.publicKey() handshake.localIndex, err = device.indices.NewIndex(peer) @@ -127,10 +133,10 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { } msg.Sender = handshake.localIndex - handshake.addToChainKey(msg.Ephemeral[:]) - handshake.addToHash(msg.Ephemeral[:]) + handshake.mixKey(msg.Ephemeral[:]) + handshake.mixHash(msg.Ephemeral[:]) - // encrypt identity key + // encrypt static key func() { var key [chacha20poly1305.KeySize]byte @@ -139,7 +145,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:]) }() - handshake.addToHash(msg.Static[:]) + handshake.mixHash(msg.Static[:]) // encrypt timestamp @@ -154,22 +160,22 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) }() - handshake.addToHash(msg.Timestamp[:]) - handshake.state = HandshakeInitialCreated + handshake.mixHash(msg.Timestamp[:]) + handshake.state = HandshakeInitiationCreated return &msg, nil } -func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { - if msg.Type != MessageInitalType { - panic(errors.New("bug: invalid inital message type")) +func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { + if msg.Type != MessageInitiationType { + return nil } - hash := addToHash(InitalHash, device.publicKey[:]) - hash = addToHash(hash, msg.Ephemeral[:]) - chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:]) + hash := mixHash(InitalHash, device.publicKey[:]) + hash = mixHash(hash, msg.Ephemeral[:]) + chainKey := mixKey(InitalChainKey, msg.Ephemeral[:]) - // decrypt identity key + // decrypt static key var err error var peerPK NoisePublicKey @@ -183,7 +189,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { if err != nil { return nil } - hash = addToHash(hash, msg.Static[:]) + hash = mixHash(hash, msg.Static[:]) // find peer @@ -210,7 +216,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { if err != nil { return nil } - hash = addToHash(hash, msg.Timestamp[:]) + hash = mixHash(hash, msg.Timestamp[:]) // check for replay attack @@ -218,7 +224,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { return nil } - // check for flood attack + // TODO: check for flood attack // update handshake state @@ -227,7 +233,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { handshake.remoteIndex = msg.Sender handshake.remoteEphemeral = msg.Ephemeral handshake.lastTimestamp = timestamp - handshake.state = HandshakeInitialConsumed + handshake.state = HandshakeInitiationConsumed return peer } @@ -236,8 +242,8 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mutex.Lock() defer handshake.mutex.Unlock() - if handshake.state != HandshakeInitialConsumed { - panic(errors.New("bug: handshake initation must be consumed first")) + if handshake.state != HandshakeInitiationConsumed { + return nil, errors.New("handshake initation must be consumed first") } // assign index @@ -260,13 +266,13 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error return nil, err } msg.Ephemeral = handshake.localEphemeral.publicKey() - handshake.addToHash(msg.Ephemeral[:]) + handshake.mixHash(msg.Ephemeral[:]) func() { ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) - handshake.addToChainKey(ss[:]) + handshake.mixKey(ss[:]) ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - handshake.addToChainKey(ss[:]) + handshake.mixKey(ss[:]) }() // add preshared key (psk) @@ -274,12 +280,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error var tau [blake2s.Size]byte var key [chacha20poly1305.KeySize]byte handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:]) - handshake.addToHash(tau[:]) + handshake.mixHash(tau[:]) func() { aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) - handshake.addToHash(msg.Empty[:]) + handshake.mixHash(msg.Empty[:]) }() handshake.state = HandshakeResponseCreated @@ -288,7 +294,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { if msg.Type != MessageResponseType { - panic(errors.New("bug: invalid message type")) + return nil } // lookup handshake by reciever @@ -300,20 +306,20 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { handshake := &peer.handshake handshake.mutex.Lock() defer handshake.mutex.Unlock() - if handshake.state != HandshakeInitialCreated { + if handshake.state != HandshakeInitiationCreated { return nil } // finish 3-way DH - hash := addToHash(handshake.hash, msg.Ephemeral[:]) + hash := mixHash(handshake.hash, msg.Ephemeral[:]) chainKey := handshake.chainKey func() { ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - chainKey = addToChainKey(chainKey, ss[:]) + chainKey = mixKey(chainKey, ss[:]) ss = device.privateKey.sharedSecret(msg.Ephemeral) - chainKey = addToChainKey(chainKey, ss[:]) + chainKey = mixKey(chainKey, ss[:]) }() // add preshared key (psk) @@ -321,7 +327,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { var tau [blake2s.Size]byte var key [chacha20poly1305.KeySize]byte chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:]) - hash = addToHash(hash, tau[:]) + hash = mixHash(hash, tau[:]) // authenticate @@ -330,7 +336,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { if err != nil { return nil } - hash = addToHash(hash, msg.Empty[:]) + hash = mixHash(hash, msg.Empty[:]) // update handshake state @@ -368,7 +374,11 @@ func (peer *Peer) NewKeyPair() *KeyPair { keyPair.sendNonce = 0 keyPair.recvNonce = 0 - peer.handshake.state = HandshakeReset + // zero handshake + + handshake.chainKey = [blake2s.Size]byte{} + handshake.localEphemeral = NoisePrivateKey{} + peer.handshake.state = HandshakeZeroed return &keyPair } diff --git a/src/noise_test.go b/src/noise_test.go index ddabf8e..8450c1c 100644 --- a/src/noise_test.go +++ b/src/noise_test.go @@ -67,13 +67,13 @@ func TestNoiseHandshake(t *testing.T) { t.Log("exchange initiation message") - msg1, err := dev1.CreateMessageInitial(peer2) + msg1, err := dev1.CreateMessageInitiation(peer2) assertNil(t, err) packet := make([]byte, 0, 256) writer := bytes.NewBuffer(packet) err = binary.Write(writer, binary.LittleEndian, msg1) - peer := dev2.ConsumeMessageInitial(msg1) + peer := dev2.ConsumeMessageInitiation(msg1) if peer == nil { t.Fatal("handshake failed at initiation message") } diff --git a/src/peer.go b/src/peer.go index f6eb555..42b9e8d 100644 --- a/src/peer.go +++ b/src/peer.go @@ -1,39 +1,64 @@ package main import ( + "errors" + "golang.org/x/crypto/blake2s" "net" "sync" "time" ) +const ( + OutboundQueueSize = 64 +) + type Peer struct { mutex sync.RWMutex endpointIP net.IP // endpointPort uint16 // persistentKeepaliveInterval time.Duration // 0 = disabled + keyPairs KeyPairs handshake Handshake device *Device + macKey [blake2s.Size]byte // Hash(Label-Mac1 || publicKey) + cookie []byte // cookie + cookieExpire time.Time + queueInbound chan []byte + queueOutbound chan *OutboundWorkQueueElement + queueOutboundRouting chan []byte } func (device *Device) NewPeer(pk NoisePublicKey) *Peer { var peer Peer + // create peer + + peer.mutex.Lock() + peer.device = device + peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize) + // map public key device.mutex.Lock() + _, ok := device.peers[pk] + if ok { + panic(errors.New("bug: adding existing peer")) + } device.peers[pk] = &peer device.mutex.Unlock() - // precompute + // precompute DH - peer.mutex.Lock() - peer.device = device - func(h *Handshake) { - h.mutex.Lock() - h.remoteStatic = pk - h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) - h.mutex.Unlock() - }(&peer.handshake) + handshake := &peer.handshake + handshake.mutex.Lock() + handshake.remoteStatic = pk + handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic) + + // compute mac key + + peer.macKey = blake2s.Sum256(append([]byte(WGLabelMAC1[:]), handshake.remoteStatic[:]...)) + + handshake.mutex.Unlock() peer.mutex.Unlock() return &peer diff --git a/src/routing.go b/src/routing.go index 553df11..4189c25 100644 --- a/src/routing.go +++ b/src/routing.go @@ -2,7 +2,6 @@ package main import ( "errors" - "fmt" "net" "sync" ) @@ -52,25 +51,3 @@ func (table *RoutingTable) LookupIPv6(address []byte) *Peer { defer table.mutex.RUnlock() return table.IPv6.Lookup(address) } - -func OutgoingRoutingWorker(device *Device, queue chan []byte) { - for { - packet := <-queue - switch packet[0] >> 4 { - - case IPv4version: - dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer := device.routingTable.LookupIPv4(dst) - fmt.Println("IPv4", peer) - - case IPv6version: - dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer := device.routingTable.LookupIPv6(dst) - fmt.Println("IPv6", peer) - - default: - // todo: log - fmt.Println("Unknown IP version") - } - } -} diff --git a/src/send.go b/src/send.go new file mode 100644 index 0000000..9790320 --- /dev/null +++ b/src/send.go @@ -0,0 +1,154 @@ +package main + +import ( + "net" + "sync" + "sync/atomic" +) + +/* Handles outbound flow + * + * 1. TUN queue + * 2. Routing + * 3. Per peer queuing + * 4. (work queuing) + * + */ + +type OutboundWorkQueueElement struct { + wg sync.WaitGroup + packet []byte + nonce uint64 + keyPair *KeyPair +} + +func (device *Device) SendPacket(packet []byte) { + + // lookup peer + + var peer *Peer + switch packet[0] >> 4 { + case IPv4version: + dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + peer = device.routingTable.LookupIPv4(dst) + + case IPv6version: + dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + peer = device.routingTable.LookupIPv6(dst) + + default: + device.logger.Println("unknown IP version") + return + } + + if peer == nil { + return + } + + // insert into peer queue + + for { + select { + case peer.queueOutboundRouting <- packet: + default: + select { + case <-peer.queueOutboundRouting: + default: + } + continue + } + break + } +} + +/* Go routine + * + * + * 1. waits for handshake. + * 2. assigns key pair & nonce + * 3. inserts to working queue + * + * TODO: avoid dynamic allocation of work queue elements + */ +func (peer *Peer) ConsumeOutboundPackets() { + for { + // wait for key pair + keyPair := func() *KeyPair { + peer.keyPairs.mutex.RLock() + defer peer.keyPairs.mutex.RUnlock() + return peer.keyPairs.current + }() + if keyPair == nil { + if len(peer.queueOutboundRouting) > 0 { + // TODO: start handshake + <-peer.keyPairs.newKeyPair + } + continue + } + + // assign packets key pair + for { + select { + case <-peer.keyPairs.newKeyPair: + default: + case <-peer.keyPairs.newKeyPair: + case packet := <-peer.queueOutboundRouting: + + // create new work element + + work := new(OutboundWorkQueueElement) + work.wg.Add(1) + work.keyPair = keyPair + work.packet = packet + work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1 + + peer.queueOutbound <- work + + // drop packets until there is room + + for { + select { + case peer.device.queueWorkOutbound <- work: + break + default: + drop := <-peer.device.queueWorkOutbound + drop.packet = nil + drop.wg.Done() + } + } + } + } + } +} + +func (peer *Peer) RoutineSequential() { + for work := range peer.queueOutbound { + work.wg.Wait() + if work.packet == nil { + continue + } + } +} + +func (device *Device) EncryptionWorker() { + for { + work := <-device.queueWorkOutbound + + func() { + defer work.wg.Done() + + // pad packet + padding := device.mtu - len(work.packet) + if padding < 0 { + work.packet = nil + return + } + for n := 0; n < padding; n += 1 { + work.packet = append(work.packet, 0) // TODO: gotta be a faster way + } + + // + + }() + } +} -- cgit v1.2.3