diff options
Diffstat (limited to '')
-rw-r--r-- | src/noise_protocol.go | 106 |
1 files changed, 58 insertions, 48 deletions
diff --git a/src/noise_protocol.go b/src/noise_protocol.go index 7f26cf1..a16908a 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -9,9 +9,9 @@ import ( ) const ( - HandshakeReset = iota - HandshakeInitialCreated - HandshakeInitialConsumed + HandshakeZeroed = iota + HandshakeInitiationCreated + HandshakeInitiationConsumed HandshakeResponseCreated HandshakeResponseConsumed ) @@ -24,13 +24,19 @@ const ( ) const ( - MessageInitalType = 1 + MessageInitiationType = 1 MessageResponseType = 2 MessageCookieResponseType = 3 MessageTransportType = 4 ) -type MessageInital struct { +/* Type is an 8-bit field, followed by 3 nul bytes, + * by marshalling the messages in little-endian byteorder + * we can treat these as a 32-bit int + * + */ + +type MessageInitiation struct { Type uint32 Sender uint32 Ephemeral NoisePublicKey @@ -73,9 +79,9 @@ type Handshake struct { } var ( - ZeroNonce [chacha20poly1305.NonceSize]byte InitalChainKey [blake2s.Size]byte InitalHash [blake2s.Size]byte + ZeroNonce [chacha20poly1305.NonceSize]byte ) func init() { @@ -83,23 +89,23 @@ func init() { InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...)) } -func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { +func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { return KDF1(c[:], data) } -func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte { +func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte { return blake2s.Sum256(append(h[:], data...)) } -func (h *Handshake) addToHash(data []byte) { - h.hash = addToHash(h.hash, data) +func (h *Handshake) mixHash(data []byte) { + h.hash = mixHash(h.hash, data) } -func (h *Handshake) addToChainKey(data []byte) { - h.chainKey = addToChainKey(h.chainKey, data) +func (h *Handshake) mixKey(data []byte) { + h.chainKey = mixKey(h.chainKey, data) } -func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { +func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { handshake := &peer.handshake handshake.mutex.Lock() defer handshake.mutex.Unlock() @@ -108,7 +114,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { var err error handshake.chainKey = InitalChainKey - handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:]) + handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:]) handshake.localEphemeral, err = newPrivateKey() if err != nil { return nil, err @@ -116,9 +122,9 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { // assign index - var msg MessageInital + var msg MessageInitiation - msg.Type = MessageInitalType + msg.Type = MessageInitiationType msg.Ephemeral = handshake.localEphemeral.publicKey() handshake.localIndex, err = device.indices.NewIndex(peer) @@ -127,10 +133,10 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { } msg.Sender = handshake.localIndex - handshake.addToChainKey(msg.Ephemeral[:]) - handshake.addToHash(msg.Ephemeral[:]) + handshake.mixKey(msg.Ephemeral[:]) + handshake.mixHash(msg.Ephemeral[:]) - // encrypt identity key + // encrypt static key func() { var key [chacha20poly1305.KeySize]byte @@ -139,7 +145,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:]) }() - handshake.addToHash(msg.Static[:]) + handshake.mixHash(msg.Static[:]) // encrypt timestamp @@ -154,22 +160,22 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) }() - handshake.addToHash(msg.Timestamp[:]) - handshake.state = HandshakeInitialCreated + handshake.mixHash(msg.Timestamp[:]) + handshake.state = HandshakeInitiationCreated return &msg, nil } -func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { - if msg.Type != MessageInitalType { - panic(errors.New("bug: invalid inital message type")) +func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { + if msg.Type != MessageInitiationType { + return nil } - hash := addToHash(InitalHash, device.publicKey[:]) - hash = addToHash(hash, msg.Ephemeral[:]) - chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:]) + hash := mixHash(InitalHash, device.publicKey[:]) + hash = mixHash(hash, msg.Ephemeral[:]) + chainKey := mixKey(InitalChainKey, msg.Ephemeral[:]) - // decrypt identity key + // decrypt static key var err error var peerPK NoisePublicKey @@ -183,7 +189,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { if err != nil { return nil } - hash = addToHash(hash, msg.Static[:]) + hash = mixHash(hash, msg.Static[:]) // find peer @@ -210,7 +216,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { if err != nil { return nil } - hash = addToHash(hash, msg.Timestamp[:]) + hash = mixHash(hash, msg.Timestamp[:]) // check for replay attack @@ -218,7 +224,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { return nil } - // check for flood attack + // TODO: check for flood attack // update handshake state @@ -227,7 +233,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { handshake.remoteIndex = msg.Sender handshake.remoteEphemeral = msg.Ephemeral handshake.lastTimestamp = timestamp - handshake.state = HandshakeInitialConsumed + handshake.state = HandshakeInitiationConsumed return peer } @@ -236,8 +242,8 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mutex.Lock() defer handshake.mutex.Unlock() - if handshake.state != HandshakeInitialConsumed { - panic(errors.New("bug: handshake initation must be consumed first")) + if handshake.state != HandshakeInitiationConsumed { + return nil, errors.New("handshake initation must be consumed first") } // assign index @@ -260,13 +266,13 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error return nil, err } msg.Ephemeral = handshake.localEphemeral.publicKey() - handshake.addToHash(msg.Ephemeral[:]) + handshake.mixHash(msg.Ephemeral[:]) func() { ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) - handshake.addToChainKey(ss[:]) + handshake.mixKey(ss[:]) ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - handshake.addToChainKey(ss[:]) + handshake.mixKey(ss[:]) }() // add preshared key (psk) @@ -274,12 +280,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error var tau [blake2s.Size]byte var key [chacha20poly1305.KeySize]byte handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:]) - handshake.addToHash(tau[:]) + handshake.mixHash(tau[:]) func() { aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) - handshake.addToHash(msg.Empty[:]) + handshake.mixHash(msg.Empty[:]) }() handshake.state = HandshakeResponseCreated @@ -288,7 +294,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { if msg.Type != MessageResponseType { - panic(errors.New("bug: invalid message type")) + return nil } // lookup handshake by reciever @@ -300,20 +306,20 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { handshake := &peer.handshake handshake.mutex.Lock() defer handshake.mutex.Unlock() - if handshake.state != HandshakeInitialCreated { + if handshake.state != HandshakeInitiationCreated { return nil } // finish 3-way DH - hash := addToHash(handshake.hash, msg.Ephemeral[:]) + hash := mixHash(handshake.hash, msg.Ephemeral[:]) chainKey := handshake.chainKey func() { ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - chainKey = addToChainKey(chainKey, ss[:]) + chainKey = mixKey(chainKey, ss[:]) ss = device.privateKey.sharedSecret(msg.Ephemeral) - chainKey = addToChainKey(chainKey, ss[:]) + chainKey = mixKey(chainKey, ss[:]) }() // add preshared key (psk) @@ -321,7 +327,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { var tau [blake2s.Size]byte var key [chacha20poly1305.KeySize]byte chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:]) - hash = addToHash(hash, tau[:]) + hash = mixHash(hash, tau[:]) // authenticate @@ -330,7 +336,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { if err != nil { return nil } - hash = addToHash(hash, msg.Empty[:]) + hash = mixHash(hash, msg.Empty[:]) // update handshake state @@ -368,7 +374,11 @@ func (peer *Peer) NewKeyPair() *KeyPair { keyPair.sendNonce = 0 keyPair.recvNonce = 0 - peer.handshake.state = HandshakeReset + // zero handshake + + handshake.chainKey = [blake2s.Size]byte{} + handshake.localEphemeral = NoisePrivateKey{} + peer.handshake.state = HandshakeZeroed return &keyPair } |