diff options
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | Makefile (renamed from src/Makefile) | 0 | ||||
-rwxr-xr-x | build.cmd (renamed from src/build.cmd) | 0 | ||||
-rw-r--r-- | conn.go (renamed from src/conn.go) | 29 | ||||
-rw-r--r-- | conn_default.go (renamed from src/conn_default.go) | 0 | ||||
-rw-r--r-- | conn_linux.go (renamed from src/conn_linux.go) | 0 | ||||
-rw-r--r-- | constants.go (renamed from src/constants.go) | 0 | ||||
-rw-r--r-- | cookie.go (renamed from src/cookie.go) | 0 | ||||
-rw-r--r-- | cookie_test.go (renamed from src/cookie_test.go) | 0 | ||||
-rw-r--r-- | daemon_darwin.go (renamed from src/daemon_darwin.go) | 0 | ||||
-rw-r--r-- | daemon_linux.go (renamed from src/daemon_linux.go) | 0 | ||||
-rw-r--r-- | daemon_windows.go (renamed from src/daemon_windows.go) | 0 | ||||
-rw-r--r-- | device.go | 364 | ||||
-rw-r--r-- | helper_test.go (renamed from src/helper_test.go) | 8 | ||||
-rw-r--r-- | index.go (renamed from src/index.go) | 0 | ||||
-rw-r--r-- | ip.go (renamed from src/ip.go) | 0 | ||||
-rw-r--r-- | kdf_test.go (renamed from src/kdf_test.go) | 0 | ||||
-rw-r--r-- | keypair.go (renamed from src/keypair.go) | 4 | ||||
-rw-r--r-- | logger.go (renamed from src/logger.go) | 0 | ||||
-rw-r--r-- | main.go (renamed from src/main.go) | 0 | ||||
-rw-r--r-- | misc.go (renamed from src/misc.go) | 0 | ||||
-rw-r--r-- | noise_helpers.go (renamed from src/noise_helpers.go) | 7 | ||||
-rw-r--r-- | noise_protocol.go (renamed from src/noise_protocol.go) | 41 | ||||
-rw-r--r-- | noise_test.go (renamed from src/noise_test.go) | 4 | ||||
-rw-r--r-- | noise_types.go (renamed from src/noise_types.go) | 0 | ||||
-rw-r--r-- | peer.go (renamed from src/peer.go) | 226 | ||||
-rw-r--r-- | ratelimiter.go (renamed from src/ratelimiter.go) | 0 | ||||
-rw-r--r-- | ratelimiter_test.go (renamed from src/ratelimiter_test.go) | 0 | ||||
-rw-r--r-- | receive.go (renamed from src/receive.go) | 43 | ||||
-rw-r--r-- | replay.go (renamed from src/replay.go) | 0 | ||||
-rw-r--r-- | replay_test.go (renamed from src/replay_test.go) | 0 | ||||
-rw-r--r-- | routing.go (renamed from src/routing.go) | 0 | ||||
-rw-r--r-- | send.go (renamed from src/send.go) | 27 | ||||
-rw-r--r-- | signal.go (renamed from src/signal.go) | 0 | ||||
-rw-r--r-- | src/device.go | 221 | ||||
-rw-r--r-- | tai64.go (renamed from src/tai64.go) | 0 | ||||
-rwxr-xr-x | tests/netns.sh (renamed from src/tests/netns.sh) | 4 | ||||
-rw-r--r-- | timer.go (renamed from src/timer.go) | 6 | ||||
-rw-r--r-- | timers.go (renamed from src/timers.go) | 169 | ||||
-rw-r--r-- | trie.go (renamed from src/trie.go) | 0 | ||||
-rw-r--r-- | trie_rand_test.go (renamed from src/trie_rand_test.go) | 0 | ||||
-rw-r--r-- | trie_test.go (renamed from src/trie_test.go) | 0 | ||||
-rw-r--r-- | tun.go (renamed from src/tun.go) | 24 | ||||
-rw-r--r-- | tun_darwin.go (renamed from src/tun_darwin.go) | 0 | ||||
-rw-r--r-- | tun_linux.go (renamed from src/tun_linux.go) | 0 | ||||
-rw-r--r-- | tun_windows.go (renamed from src/tun_windows.go) | 0 | ||||
-rw-r--r-- | uapi.go (renamed from src/uapi.go) | 165 | ||||
-rw-r--r-- | uapi_darwin.go (renamed from src/uapi_darwin.go) | 0 | ||||
-rw-r--r-- | uapi_linux.go (renamed from src/uapi_linux.go) | 0 | ||||
-rw-r--r-- | uapi_windows.go (renamed from src/uapi_windows.go) | 0 | ||||
-rw-r--r-- | xchacha20.go (renamed from src/xchacha20.go) | 0 | ||||
-rw-r--r-- | xchacha20_test.go (renamed from src/xchacha20_test.go) | 0 |
52 files changed, 864 insertions, 479 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e460293 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +wireguard-go diff --git a/src/build.cmd b/build.cmd index 52cb883..52cb883 100755 --- a/src/build.cmd +++ b/build.cmd @@ -64,13 +64,13 @@ func unsafeCloseBind(device *Device) error { return err } -func updateBind(device *Device) error { - device.mutex.Lock() - defer device.mutex.Unlock() +func (device *Device) BindUpdate() error { - netc := &device.net - netc.mutex.Lock() - defer netc.mutex.Unlock() + device.net.mutex.Lock() + defer device.net.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() // close existing sockets @@ -78,17 +78,14 @@ func updateBind(device *Device) error { return err } - // assumption: netc.update WaitGroup should be exactly 1 - // open new sockets - if device.tun.isUp.Get() { - - device.log.Debug.Println("UDP bind updating") + if device.isUp.Get() { // bind to new port var err error + netc := &device.net netc.bind, netc.port, err = CreateBind(netc.port) if err != nil { netc.bind = nil @@ -104,15 +101,15 @@ func updateBind(device *Device) error { // clear cached source addresses - for _, peer := range device.peers { + for _, peer := range device.peers.keyMap { peer.mutex.Lock() + defer peer.mutex.Unlock() if peer.endpoint != nil { peer.endpoint.ClearSrc() } - peer.mutex.Unlock() } - // decrease waitgroup to 0 + // start receiving routines go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) @@ -123,11 +120,9 @@ func updateBind(device *Device) error { return nil } -func closeBind(device *Device) error { - device.mutex.Lock() +func (device *Device) BindClose() error { device.net.mutex.Lock() err := unsafeCloseBind(device) device.net.mutex.Unlock() - device.mutex.Unlock() return err } diff --git a/src/conn_default.go b/conn_default.go index 5b73c90..5b73c90 100644 --- a/src/conn_default.go +++ b/conn_default.go diff --git a/src/conn_linux.go b/conn_linux.go index cdba74f..cdba74f 100644 --- a/src/conn_linux.go +++ b/conn_linux.go diff --git a/src/constants.go b/constants.go index 71dd98e..71dd98e 100644 --- a/src/constants.go +++ b/constants.go diff --git a/src/cookie.go b/cookie.go index a13ad49..a13ad49 100644 --- a/src/cookie.go +++ b/cookie.go diff --git a/src/cookie_test.go b/cookie_test.go index d745fe7..d745fe7 100644 --- a/src/cookie_test.go +++ b/cookie_test.go diff --git a/src/daemon_darwin.go b/daemon_darwin.go index 913af0e..913af0e 100644 --- a/src/daemon_darwin.go +++ b/daemon_darwin.go diff --git a/src/daemon_linux.go b/daemon_linux.go index e1aaede..e1aaede 100644 --- a/src/daemon_linux.go +++ b/daemon_linux.go diff --git a/src/daemon_windows.go b/daemon_windows.go index d5ec1e8..d5ec1e8 100644 --- a/src/daemon_windows.go +++ b/daemon_windows.go diff --git a/device.go b/device.go new file mode 100644 index 0000000..22e0990 --- /dev/null +++ b/device.go @@ -0,0 +1,364 @@ +package main + +import ( + "runtime" + "sync" + "sync/atomic" + "time" +) + +type Device struct { + isUp AtomicBool // device is (going) up + isClosed AtomicBool // device is closed? (acting as guard) + log *Logger + + // synchronized resources (locks acquired in order) + + state struct { + mutex sync.Mutex + changing AtomicBool + current bool + } + + net struct { + mutex sync.RWMutex + bind Bind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) + } + + noise struct { + mutex sync.RWMutex + privateKey NoisePrivateKey + publicKey NoisePublicKey + } + + routing struct { + mutex sync.RWMutex + table RoutingTable + } + + peers struct { + mutex sync.RWMutex + keyMap map[NoisePublicKey]*Peer + } + + // unprotected / "self-synchronising resources" + + indices IndexTable + mac CookieChecker + + rate struct { + underLoadUntil atomic.Value + limiter Ratelimiter + } + + pool struct { + messageBuffers sync.Pool + } + + queue struct { + encryption chan *QueueOutboundElement + decryption chan *QueueInboundElement + handshake chan QueueHandshakeElement + } + + signal struct { + stop Signal + } + + tun struct { + device TUNDevice + mtu int32 + } +} + +/* Converts the peer into a "zombie", which remains in the peer map, + * but processes no packets and does not exists in the routing table. + * + * Must hold: + * device.peers.mutex : exclusive lock + * device.routing : exclusive lock + */ +func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { + + // stop routing and processing of packets + + device.routing.table.RemovePeer(peer) + peer.Stop() + + // remove from peer map + + delete(device.peers.keyMap, key) +} + +func deviceUpdateState(device *Device) { + + // check if state already being updated (guard) + + if device.state.changing.Swap(true) { + return + } + + // compare to current state of device + + device.state.mutex.Lock() + + newIsUp := device.isUp.Get() + + if newIsUp == device.state.current { + device.state.changing.Set(false) + device.state.mutex.Unlock() + return + } + + // change state of device + + switch newIsUp { + case true: + if err := device.BindUpdate(); err != nil { + device.isUp.Set(false) + break + } + device.peers.mutex.Lock() + for _, peer := range device.peers.keyMap { + peer.Start() + } + device.peers.mutex.Unlock() + + case false: + device.BindClose() + device.peers.mutex.Lock() + for _, peer := range device.peers.keyMap { + peer.Stop() + } + device.peers.mutex.Unlock() + } + + // update state variables + + device.state.current = newIsUp + device.state.changing.Set(false) + device.state.mutex.Unlock() + + // check for state change in the mean time + + deviceUpdateState(device) +} + +func (device *Device) Up() { + + // closed device cannot be brought up + + if device.isClosed.Get() { + return + } + + device.state.mutex.Lock() + device.isUp.Set(true) + device.state.mutex.Unlock() + deviceUpdateState(device) +} + +func (device *Device) Down() { + device.state.mutex.Lock() + device.isUp.Set(false) + device.state.mutex.Unlock() + deviceUpdateState(device) +} + +func (device *Device) IsUnderLoad() bool { + + // check if currently under load + + now := time.Now() + underLoad := len(device.queue.handshake) >= UnderLoadQueueSize + if underLoad { + device.rate.underLoadUntil.Store(now.Add(time.Second)) + return true + } + + // check if recently under load + + until := device.rate.underLoadUntil.Load().(time.Time) + return until.After(now) +} + +func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { + + // lock required resources + + device.noise.mutex.Lock() + defer device.noise.mutex.Unlock() + + device.routing.mutex.Lock() + defer device.routing.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + for _, peer := range device.peers.keyMap { + peer.handshake.mutex.RLock() + defer peer.handshake.mutex.RUnlock() + } + + // remove peers with matching public keys + + publicKey := sk.publicKey() + for key, peer := range device.peers.keyMap { + if peer.handshake.remoteStatic.Equals(publicKey) { + unsafeRemovePeer(device, peer, key) + } + } + + // update key material + + device.noise.privateKey = sk + device.noise.publicKey = publicKey + device.mac.Init(publicKey) + + // do static-static DH pre-computations + + rmKey := device.noise.privateKey.IsZero() + + for key, peer := range device.peers.keyMap { + + hs := &peer.handshake + + if rmKey { + hs.precomputedStaticStatic = [NoisePublicKeySize]byte{} + } else { + hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic) + } + + if isZero(hs.precomputedStaticStatic[:]) { + unsafeRemovePeer(device, peer, key) + } + } + + return nil +} + +func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { + return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) +} + +func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { + device.pool.messageBuffers.Put(msg) +} + +func NewDevice(tun TUNDevice, logger *Logger) *Device { + device := new(Device) + + device.isUp.Set(false) + device.isClosed.Set(false) + + device.log = logger + device.tun.device = tun + device.peers.keyMap = make(map[NoisePublicKey]*Peer) + + // initialize anti-DoS / anti-scanning features + + device.rate.limiter.Init() + device.rate.underLoadUntil.Store(time.Time{}) + + // initialize noise & crypt-key routine + + device.indices.Init() + device.routing.table.Reset() + + // setup buffer pool + + device.pool.messageBuffers = sync.Pool{ + New: func() interface{} { + return new([MaxMessageSize]byte) + }, + } + + // create queues + + device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) + device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) + device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) + + // prepare signals + + device.signal.stop = NewSignal() + + // prepare net + + device.net.port = 0 + device.net.bind = nil + + // start workers + + for i := 0; i < runtime.NumCPU(); i += 1 { + go device.RoutineEncryption() + go device.RoutineDecryption() + go device.RoutineHandshake() + } + + go device.RoutineReadFromTUN() + go device.RoutineTUNEventReader() + go device.rate.limiter.RoutineGarbageCollector(device.signal.stop) + + return device +} + +func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { + device.peers.mutex.RLock() + defer device.peers.mutex.RUnlock() + + return device.peers.keyMap[pk] +} + +func (device *Device) RemovePeer(key NoisePublicKey) { + device.noise.mutex.Lock() + defer device.noise.mutex.Unlock() + + device.routing.mutex.Lock() + defer device.routing.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + // stop peer and remove from routing + + peer, ok := device.peers.keyMap[key] + if ok { + unsafeRemovePeer(device, peer, key) + } +} + +func (device *Device) RemoveAllPeers() { + + device.routing.mutex.Lock() + defer device.routing.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + for key, peer := range device.peers.keyMap { + println("rm", peer.String()) + unsafeRemovePeer(device, peer, key) + } + + device.peers.keyMap = make(map[NoisePublicKey]*Peer) +} + +func (device *Device) Close() { + device.log.Info.Println("Device closing") + if device.isClosed.Swap(true) { + return + } + device.signal.stop.Broadcast() + device.tun.device.Close() + device.BindClose() + device.isUp.Set(false) + device.RemoveAllPeers() + device.log.Info.Println("Interface closed") +} + +func (device *Device) Wait() chan struct{} { + return device.signal.stop.Wait() +} diff --git a/src/helper_test.go b/helper_test.go index 8548121..41e6b72 100644 --- a/src/helper_test.go +++ b/helper_test.go @@ -28,8 +28,8 @@ func (tun *DummyTUN) MTU() (int, error) { return tun.mtu, nil } -func (tun *DummyTUN) Write(d []byte) (int, error) { - tun.packets <- d +func (tun *DummyTUN) Write(d []byte, offset int) (int, error) { + tun.packets <- d[offset:] return len(d), nil } @@ -41,9 +41,9 @@ func (tun *DummyTUN) Events() chan TUNEvent { return tun.events } -func (tun *DummyTUN) Read(d []byte) (int, error) { +func (tun *DummyTUN) Read(d []byte, offset int) (int, error) { t := <-tun.packets - copy(d, t) + copy(d[offset:], t) return len(t), nil } diff --git a/src/kdf_test.go b/kdf_test.go index a89dacc..a89dacc 100644 --- a/src/kdf_test.go +++ b/kdf_test.go diff --git a/src/keypair.go b/keypair.go index 7e5297b..283cb92 100644 --- a/src/keypair.go +++ b/keypair.go @@ -38,5 +38,7 @@ func (kp *KeyPairs) Current() *KeyPair { } func (device *Device) DeleteKeyPair(key *KeyPair) { - device.indices.Delete(key.localIndex) + if key != nil { + device.indices.Delete(key.localIndex) + } } diff --git a/src/logger.go b/logger.go index 0872ef9..0872ef9 100644 --- a/src/logger.go +++ b/logger.go diff --git a/src/noise_helpers.go b/noise_helpers.go index 24302c0..1e2de5f 100644 --- a/src/noise_helpers.go +++ b/noise_helpers.go @@ -3,6 +3,7 @@ package main import ( "crypto/hmac" "crypto/rand" + "crypto/subtle" "golang.org/x/crypto/blake2s" "golang.org/x/crypto/curve25519" "hash" @@ -58,11 +59,11 @@ func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { } func isZero(val []byte) bool { - var acc byte + acc := 1 for _, b := range val { - acc |= b + acc &= subtle.ConstantTimeByteEq(b, 0) } - return acc == 0 + return acc == 1 } func setZero(arr []byte) { diff --git a/src/noise_protocol.go b/noise_protocol.go index 2f9e1d5..c9713c0 100644 --- a/src/noise_protocol.go +++ b/noise_protocol.go @@ -121,6 +121,15 @@ func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { hsh.Reset() } +func (h *Handshake) Clear() { + setZero(h.localEphemeral[:]) + setZero(h.remoteEphemeral[:]) + setZero(h.chainKey[:]) + setZero(h.hash[:]) + h.localIndex = 0 + h.state = HandshakeZeroed +} + func (h *Handshake) mixHash(data []byte) { mixHash(&h.hash, &h.hash, data) } @@ -137,6 +146,10 @@ func init() { } func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + handshake := &peer.handshake handshake.mutex.Lock() defer handshake.mutex.Unlock() @@ -187,7 +200,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e ss[:], ) aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:]) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:]) }() handshake.mixHash(msg.Static[:]) @@ -212,16 +225,19 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e } func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { - if msg.Type != MessageInitiationType { - return nil - } - var ( hash [blake2s.Size]byte chainKey [blake2s.Size]byte ) - mixHash(&hash, &InitialHash, device.publicKey[:]) + if msg.Type != MessageInitiationType { + return nil + } + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + + mixHash(&hash, &InitialHash, device.noise.publicKey[:]) mixHash(&hash, &hash, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) @@ -231,7 +247,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { var peerPK NoisePublicKey func() { var key [chacha20poly1305.KeySize]byte - ss := device.privateKey.sharedSecret(msg.Ephemeral) + ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) KDF2(&chainKey, &key, chainKey[:], ss[:]) aead, _ := chacha20poly1305.New(key[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) @@ -386,7 +402,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { ok := func() bool { - // read lock handshake + // lock handshake state handshake.mutex.RLock() defer handshake.mutex.RUnlock() @@ -395,6 +411,11 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { return false } + // lock private key for reading + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + // finish 3-way DH mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) @@ -407,7 +428,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { }() func() { - ss := device.privateKey.sharedSecret(msg.Ephemeral) + ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) mixKey(&chainKey, &chainKey, ss[:]) setZero(ss[:]) }() @@ -425,7 +446,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { ) mixHash(&hash, &hash, tau[:]) - // authenticate + // authenticate transcript aead, _ := chacha20poly1305.New(key[:]) _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) diff --git a/src/noise_test.go b/noise_test.go index 0d7f0e9..5e9d44b 100644 --- a/src/noise_test.go +++ b/noise_test.go @@ -31,8 +31,8 @@ func TestNoiseHandshake(t *testing.T) { defer dev1.Close() defer dev2.Close() - peer1, _ := dev2.NewPeer(dev1.privateKey.publicKey()) - peer2, _ := dev1.NewPeer(dev2.privateKey.publicKey()) + peer1, _ := dev2.NewPeer(dev1.noise.privateKey.publicKey()) + peer2, _ := dev1.NewPeer(dev2.noise.privateKey.publicKey()) assertEqual( t, diff --git a/src/noise_types.go b/noise_types.go index 1a944df..1a944df 100644 --- a/src/noise_types.go +++ b/noise_types.go @@ -8,25 +8,32 @@ import ( "time" ) +const ( + PeerRoutineNumber = 4 +) + type Peer struct { - id uint + isRunning AtomicBool mutex sync.RWMutex persistentKeepaliveInterval uint64 keyPairs KeyPairs handshake Handshake device *Device endpoint Endpoint - stats struct { + + stats struct { txBytes uint64 // bytes send to peer (endpoint) rxBytes uint64 // bytes received from peer lastHandshakeNano int64 // nano seconds since epoch } + time struct { mutex sync.RWMutex lastSend time.Time // last send message lastHandshake time.Time // last completed handshake nextKeepalive time.Time } + signal struct { newKeyPair Signal // size 1, new key pair was generated handshakeCompleted Signal // size 1, handshake completed @@ -34,30 +41,62 @@ type Peer struct { flushNonceQueue Signal // size 1, empty queued packets messageSend Signal // size 1, message was send to peer messageReceived Signal // size 1, authenticated message recv - stop Signal // size 0, stop all goroutines } + timer struct { + // state related to WireGuard timers - keepalivePersistent Timer // set for persistent keepalives - keepalivePassive Timer // set upon recieving messages - newHandshake Timer // begin a new handshake (stale) + keepalivePersistent Timer // set for persistent keep-alive + keepalivePassive Timer // set upon receiving messages zeroAllKeys Timer // zero all key material + handshakeNew Timer // begin a new handshake (stale) handshakeDeadline Timer // complete handshake timeout handshakeTimeout Timer // current handshake message timeout sendLastMinuteHandshake bool needAnotherKeepalive bool } + queue struct { nonce chan *QueueOutboundElement // nonce / pre-handshake queue outbound chan *QueueOutboundElement // sequential ordering of work inbound chan *QueueInboundElement // sequential ordering of work } + + routines struct { + mutex sync.Mutex // held when stopping / starting routines + starting sync.WaitGroup // routines pending start + stopping sync.WaitGroup // routines pending stop + stop Signal // size 0, stop all go-routines in peer + } + mac CookieGenerator } func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { + + if device.isClosed.Get() { + return nil, errors.New("Device closed") + } + + // lock resources + + device.state.mutex.Lock() + defer device.state.mutex.Unlock() + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + // check if over limit + + if len(device.peers.keyMap) >= MaxPeers { + return nil, errors.New("Too many peers") + } + // create peer peer := new(Peer) @@ -66,66 +105,46 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.mac.Init(pk) peer.device = device + peer.isRunning.Set(false) + peer.timer.zeroAllKeys = NewTimer() peer.timer.keepalivePersistent = NewTimer() peer.timer.keepalivePassive = NewTimer() - peer.timer.newHandshake = NewTimer() - peer.timer.zeroAllKeys = NewTimer() + peer.timer.handshakeNew = NewTimer() peer.timer.handshakeDeadline = NewTimer() peer.timer.handshakeTimeout = NewTimer() - // assign id for debugging - - device.mutex.Lock() - peer.id = device.idCounter - device.idCounter += 1 - - // check if over limit - - if len(device.peers) >= MaxPeers { - return nil, errors.New("Too many peers") - } - // map public key - _, ok := device.peers[pk] + _, ok := device.peers.keyMap[pk] if ok { return nil, errors.New("Adding existing peer") } - device.peers[pk] = peer - device.mutex.Unlock() + device.peers.keyMap[pk] = peer - // precompute DH + // pre-compute DH handshake := &peer.handshake handshake.mutex.Lock() handshake.remoteStatic = pk - handshake.precomputedStaticStatic = - device.privateKey.sharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk) handshake.mutex.Unlock() // reset endpoint peer.endpoint = nil - // prepare queuing - - peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) - // prepare signaling & routines - peer.signal.stop = NewSignal() - peer.signal.newKeyPair = NewSignal() - peer.signal.handshakeBegin = NewSignal() - peer.signal.handshakeCompleted = NewSignal() - peer.signal.flushNonceQueue = NewSignal() + peer.routines.mutex.Lock() + peer.routines.stop = NewSignal() + peer.routines.mutex.Unlock() - go peer.RoutineNonce() - go peer.RoutineTimerHandler() - go peer.RoutineSequentialSender() - go peer.RoutineSequentialReceiver() + // start peer + + if peer.device.isUp.Get() { + peer.Start() + } return peer, nil } @@ -133,32 +152,143 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { func (peer *Peer) SendBuffer(buffer []byte) error { peer.device.net.mutex.RLock() defer peer.device.net.mutex.RUnlock() + + if peer.device.net.bind == nil { + return errors.New("No bind") + } + peer.mutex.RLock() defer peer.mutex.RUnlock() + if peer.endpoint == nil { return errors.New("No known endpoint for peer") } + return peer.device.net.bind.Send(buffer, peer.endpoint) } -/* Returns a short string identification for logging +/* Returns a short string identifier for logging */ func (peer *Peer) String() string { if peer.endpoint == nil { return fmt.Sprintf( - "peer(%d unknown %s)", - peer.id, + "peer(unknown %s)", base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } return fmt.Sprintf( - "peer(%d %s %s)", - peer.id, + "peer(%s %s)", peer.endpoint.DstToString(), base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), ) } -func (peer *Peer) Close() { - peer.signal.stop.Broadcast() +func (peer *Peer) Start() { + + // should never start a peer on a closed device + + if peer.device.isClosed.Get() { + return + } + + // prevent simultaneous start/stop operations + + peer.routines.mutex.Lock() + defer peer.routines.mutex.Unlock() + + if peer.isRunning.Get() { + return + } + + peer.device.log.Debug.Println("Starting:", peer.String()) + + // sanity check : these should be 0 + + peer.routines.starting.Wait() + peer.routines.stopping.Wait() + + // prepare queues and signals + + peer.signal.newKeyPair = NewSignal() + peer.signal.handshakeBegin = NewSignal() + peer.signal.handshakeCompleted = NewSignal() + peer.signal.flushNonceQueue = NewSignal() + + peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + + peer.routines.stop = NewSignal() + peer.isRunning.Set(true) + + // wait for routines to start + + peer.routines.starting.Add(PeerRoutineNumber) + peer.routines.stopping.Add(PeerRoutineNumber) + + go peer.RoutineNonce() + go peer.RoutineTimerHandler() + go peer.RoutineSequentialSender() + go peer.RoutineSequentialReceiver() + + peer.routines.starting.Wait() + peer.isRunning.Set(true) +} + +func (peer *Peer) Stop() { + + // prevent simultaneous start/stop operations + + peer.routines.mutex.Lock() + defer peer.routines.mutex.Unlock() + + if !peer.isRunning.Swap(false) { + return + } + + device := peer.device + device.log.Debug.Println("Stopping:", peer.String()) + + // stop & wait for ongoing peer routines + + peer.routines.stop.Broadcast() + peer.routines.starting.Wait() + peer.routines.stopping.Wait() + + // stop timers + + peer.timer.keepalivePersistent.Stop() + peer.timer.keepalivePassive.Stop() + peer.timer.zeroAllKeys.Stop() + peer.timer.handshakeNew.Stop() + peer.timer.handshakeDeadline.Stop() + peer.timer.handshakeTimeout.Stop() + + // close queues + + close(peer.queue.nonce) + close(peer.queue.outbound) + close(peer.queue.inbound) + + // clear key pairs + + kp := &peer.keyPairs + kp.mutex.Lock() + + device.DeleteKeyPair(kp.previous) + device.DeleteKeyPair(kp.current) + device.DeleteKeyPair(kp.next) + + kp.previous = nil + kp.current = nil + kp.next = nil + kp.mutex.Unlock() + + // clear handshake state + + hs := &peer.handshake + hs.mutex.Lock() + device.indices.Delete(hs.localIndex) + hs.Clear() + hs.mutex.Unlock() } diff --git a/src/ratelimiter.go b/ratelimiter.go index 6e5f005..6e5f005 100644 --- a/src/ratelimiter.go +++ b/ratelimiter.go diff --git a/src/ratelimiter_test.go b/ratelimiter_test.go index 13b6a23..13b6a23 100644 --- a/src/ratelimiter_test.go +++ b/ratelimiter_test.go diff --git a/src/receive.go b/receive.go index dbd2813..1f44df2 100644 --- a/src/receive.go +++ b/receive.go @@ -123,7 +123,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { case ipv6.Version: size, endpoint, err = bind.ReceiveIPv6(buffer[:]) default: - return + panic("invalid IP version") } if err != nil { @@ -184,9 +184,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { // add to decryption queues - device.addToDecryptionQueue(device.queue.decryption, elem) - device.addToInboundQueue(peer.queue.inbound, elem) - buffer = device.GetMessageBuffer() + if peer.isRunning.Get() { + device.addToDecryptionQueue(device.queue.decryption, elem) + device.addToInboundQueue(peer.queue.inbound, elem) + buffer = device.GetMessageBuffer() + } continue @@ -308,13 +310,20 @@ func (device *Device) RoutineHandshake() { return } - // lookup peer and consume response + // lookup peer from index entry := device.indices.Lookup(reply.Receiver) + if entry.peer == nil { - return + continue + } + + // consume reply + + if peer := entry.peer; peer.isRunning.Get() { + peer.mac.ConsumeReply(&reply) } - entry.peer.mac.ConsumeReply(&reply) + continue case MessageInitiationType, MessageResponseType: @@ -323,7 +332,7 @@ func (device *Device) RoutineHandshake() { if !device.mac.CheckMAC1(elem.packet) { logDebug.Println("Received packet with invalid mac1") - return + continue } // endpoints destination address is the source of the datagram @@ -347,7 +356,7 @@ func (device *Device) RoutineHandshake() { reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) if err != nil { logError.Println("Failed to create cookie reply:", err) - return + continue } // marshal and send reply @@ -363,7 +372,7 @@ func (device *Device) RoutineHandshake() { // check ratelimiter - if !device.ratelimiter.Allow(elem.endpoint.DstIP()) { + if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { continue } } @@ -486,19 +495,23 @@ func (device *Device) RoutineHandshake() { func (peer *Peer) RoutineSequentialReceiver() { + defer peer.routines.stopping.Done() + device := peer.device logInfo := device.log.Info logError := device.log.Error logDebug := device.log.Debug - logDebug.Println("Routine, sequential receiver, started for peer", peer.id) + logDebug.Println("Routine, sequential receiver, started for peer", peer.String()) + + peer.routines.starting.Done() for { select { - case <-peer.signal.stop.Wait(): - logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id) + case <-peer.routines.stop.Wait(): + logDebug.Println("Routine, sequential receiver, stopped for peer", peer.String()) return case elem := <-peer.queue.inbound: @@ -572,7 +585,7 @@ func (peer *Peer) RoutineSequentialReceiver() { // verify IPv4 source src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.routingTable.LookupIPv4(src) != peer { + if device.routing.table.LookupIPv4(src) != peer { logInfo.Println( "IPv4 packet with disallowed source address from", peer.String(), @@ -600,7 +613,7 @@ func (peer *Peer) RoutineSequentialReceiver() { // verify IPv6 source src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.routingTable.LookupIPv6(src) != peer { + if device.routing.table.LookupIPv6(src) != peer { logInfo.Println( "IPv6 packet with disallowed source address from", peer.String(), diff --git a/src/replay.go b/replay.go index 5d42860..5d42860 100644 --- a/src/replay.go +++ b/replay.go diff --git a/src/replay_test.go b/replay_test.go index 228fce6..228fce6 100644 --- a/src/replay_test.go +++ b/replay_test.go diff --git a/src/routing.go b/routing.go index 2a2e237..2a2e237 100644 --- a/src/routing.go +++ b/routing.go @@ -151,14 +151,14 @@ func (device *Device) RoutineReadFromTUN() { continue } dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.routingTable.LookupIPv4(dst) + peer = device.routing.table.LookupIPv4(dst) case ipv6.Version: if len(elem.packet) < ipv6.HeaderLen { continue } dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.routingTable.LookupIPv6(dst) + peer = device.routing.table.LookupIPv6(dst) default: logDebug.Println("Received packet with unknown IP version") @@ -170,9 +170,11 @@ func (device *Device) RoutineReadFromTUN() { // insert into nonce/pre-handshake queue - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - addToOutboundQueue(peer.queue.nonce, elem) - elem = device.NewOutboundElement() + if peer.isRunning.Get() { + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) + addToOutboundQueue(peer.queue.nonce, elem) + elem = device.NewOutboundElement() + } } } @@ -185,14 +187,18 @@ func (device *Device) RoutineReadFromTUN() { func (peer *Peer) RoutineNonce() { var keyPair *KeyPair + defer peer.routines.stopping.Done() + device := peer.device logDebug := device.log.Debug logDebug.Println("Routine, nonce worker, started for peer", peer.String()) + peer.routines.starting.Done() + for { NextPacket: select { - case <-peer.signal.stop.Wait(): + case <-peer.routines.stop.Wait(): return case elem := <-peer.queue.nonce: @@ -217,7 +223,7 @@ func (peer *Peer) RoutineNonce() { logDebug.Println("Clearing queue for", peer.String()) peer.FlushNonceQueue() goto NextPacket - case <-peer.signal.stop.Wait(): + case <-peer.routines.stop.Wait(): return } } @@ -309,15 +315,20 @@ func (device *Device) RoutineEncryption() { * The routine terminates then the outbound queue is closed. */ func (peer *Peer) RoutineSequentialSender() { + + defer peer.routines.stopping.Done() + device := peer.device logDebug := device.log.Debug logDebug.Println("Routine, sequential sender, started for", peer.String()) + peer.routines.starting.Done() + for { select { - case <-peer.signal.stop.Wait(): + case <-peer.routines.stop.Wait(): logDebug.Println( "Routine, sequential sender, stopped for", peer.String()) return diff --git a/src/signal.go b/signal.go index 2cefad4..2cefad4 100644 --- a/src/signal.go +++ b/signal.go diff --git a/src/device.go b/src/device.go deleted file mode 100644 index a3461ad..0000000 --- a/src/device.go +++ /dev/null @@ -1,221 +0,0 @@ -package main - -import ( - "runtime" - "sync" - "sync/atomic" - "time" -) - -type Device struct { - closed AtomicBool // device is closed? (acting as guard) - log *Logger // collection of loggers for levels - idCounter uint // for assigning debug ids to peers - fwMark uint32 - tun struct { - device TUNDevice - isUp AtomicBool - mtu int32 - } - pool struct { - messageBuffers sync.Pool - } - net struct { - mutex sync.RWMutex - bind Bind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) - } - mutex sync.RWMutex - privateKey NoisePrivateKey - publicKey NoisePublicKey - routingTable RoutingTable - indices IndexTable - queue struct { - encryption chan *QueueOutboundElement - decryption chan *QueueInboundElement - handshake chan QueueHandshakeElement - } - signal struct { - stop Signal - } - underLoadUntil atomic.Value - ratelimiter Ratelimiter - peers map[NoisePublicKey]*Peer - mac CookieChecker -} - -/* Warning: - * The caller must hold the device mutex (write lock) - */ -func removePeerUnsafe(device *Device, key NoisePublicKey) { - peer, ok := device.peers[key] - if !ok { - return - } - peer.mutex.Lock() - device.routingTable.RemovePeer(peer) - delete(device.peers, key) - peer.Close() -} - -func (device *Device) IsUnderLoad() bool { - - // check if currently under load - - now := time.Now() - underLoad := len(device.queue.handshake) >= UnderLoadQueueSize - if underLoad { - device.underLoadUntil.Store(now.Add(time.Second)) - return true - } - - // check if recently under load - - until := device.underLoadUntil.Load().(time.Time) - return until.After(now) -} - -func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { - device.mutex.Lock() - defer device.mutex.Unlock() - - // remove peers with matching public keys - - publicKey := sk.publicKey() - for key, peer := range device.peers { - h := &peer.handshake - h.mutex.RLock() - if h.remoteStatic.Equals(publicKey) { - removePeerUnsafe(device, key) - } - h.mutex.RUnlock() - } - - // update key material - - device.privateKey = sk - device.publicKey = publicKey - device.mac.Init(publicKey) - - // do DH precomputations - - rmKey := device.privateKey.IsZero() - - for key, peer := range device.peers { - h := &peer.handshake - h.mutex.Lock() - if rmKey { - h.precomputedStaticStatic = [NoisePublicKeySize]byte{} - } else { - h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) - if isZero(h.precomputedStaticStatic[:]) { - removePeerUnsafe(device, key) - } - } - h.mutex.Unlock() - } - - return nil -} - -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) -} - -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - device.pool.messageBuffers.Put(msg) -} - -func NewDevice(tun TUNDevice, logger *Logger) *Device { - device := new(Device) - device.mutex.Lock() - defer device.mutex.Unlock() - - device.log = logger - device.peers = make(map[NoisePublicKey]*Peer) - device.tun.device = tun - device.tun.isUp.Set(false) - - device.indices.Init() - device.ratelimiter.Init() - - device.routingTable.Reset() - device.underLoadUntil.Store(time.Time{}) - - // setup buffer pool - - device.pool.messageBuffers = sync.Pool{ - New: func() interface{} { - return new([MaxMessageSize]byte) - }, - } - - // create queues - - device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) - device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) - device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) - - // prepare signals - - device.signal.stop = NewSignal() - - // prepare net - - device.net.port = 0 - device.net.bind = nil - - // start workers - - for i := 0; i < runtime.NumCPU(); i += 1 { - go device.RoutineEncryption() - go device.RoutineDecryption() - go device.RoutineHandshake() - } - - go device.RoutineReadFromTUN() - go device.RoutineTUNEventReader() - go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) - - return device -} - -func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { - device.mutex.RLock() - defer device.mutex.RUnlock() - return device.peers[pk] -} - -func (device *Device) RemovePeer(key NoisePublicKey) { - device.mutex.Lock() - defer device.mutex.Unlock() - removePeerUnsafe(device, key) -} - -func (device *Device) RemoveAllPeers() { - device.mutex.Lock() - defer device.mutex.Unlock() - - for key, peer := range device.peers { - peer.mutex.Lock() - delete(device.peers, key) - peer.Close() - peer.mutex.Unlock() - } -} - -func (device *Device) Close() { - if device.closed.Swap(true) { - return - } - device.log.Info.Println("Closing device") - device.RemoveAllPeers() - device.signal.stop.Broadcast() - device.tun.device.Close() - closeBind(device) -} - -func (device *Device) Wait() chan struct{} { - return device.signal.stop.Wait() -} diff --git a/src/tests/netns.sh b/tests/netns.sh index 02d428b..6c47a44 100755 --- a/src/tests/netns.sh +++ b/tests/netns.sh @@ -80,11 +80,11 @@ pp ip netns add $netns2 ip0 link set up dev lo # ip0 link add dev wg1 type wireguard -n0 $program wg1 +n0 $program -f wg1 & ip0 link set wg1 netns $netns1 # ip0 link add dev wg1 type wireguard -n0 $program wg2 +n0 $program -f wg2 & ip0 link set wg2 netns $netns2 key1="$(pp wg genkey)" @@ -43,12 +43,6 @@ func (t *Timer) Reset(dur time.Duration) { t.Start(dur) } -func (t *Timer) Push(dur time.Duration) { - if t.pending.Get() { - t.Reset(dur) - } -} - func (t *Timer) Wait() <-chan time.Time { return t.timer.C } diff --git a/src/timers.go b/timers.go index ee47393..7092688 100644 --- a/src/timers.go +++ b/timers.go @@ -8,6 +8,12 @@ import ( "time"
)
+/* NOTE:
+ * Notion of validity
+ *
+ *
+ */
+
/* Called when a new authenticated message has been send
*
*/
@@ -44,25 +50,25 @@ func (peer *Peer) KeepKeyFreshReceiving() { send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
if send {
// do a last minute attempt at initiating a new handshake
- peer.signal.handshakeBegin.Send()
peer.timer.sendLastMinuteHandshake = true
+ peer.signal.handshakeBegin.Send()
}
}
/* Queues a keep-alive if no packets are queued for peer
*/
func (peer *Peer) SendKeepAlive() bool {
+ if len(peer.queue.nonce) != 0 {
+ return false
+ }
elem := peer.device.NewOutboundElement()
elem.packet = nil
- if len(peer.queue.nonce) == 0 {
- select {
- case peer.queue.nonce <- elem:
- return true
- default:
- return false
- }
+ select {
+ case peer.queue.nonce <- elem:
+ return true
+ default:
+ return false
}
- return true
}
/* Event:
@@ -70,9 +76,7 @@ func (peer *Peer) SendKeepAlive() bool { */
func (peer *Peer) TimerDataSent() {
peer.timer.keepalivePassive.Stop()
- if peer.timer.newHandshake.Pending() {
- peer.timer.newHandshake.Reset(NewHandshakeTime)
- }
+ peer.timer.handshakeNew.Start(NewHandshakeTime)
}
/* Event:
@@ -91,7 +95,7 @@ func (peer *Peer) TimerDataReceived() { * Any (authenticated) packet received
*/
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
- peer.timer.newHandshake.Stop()
+ peer.timer.handshakeNew.Stop()
}
/* Event:
@@ -115,10 +119,6 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { * - First transport message under the "next" key
*/
func (peer *Peer) TimerHandshakeComplete() {
- atomic.StoreInt64(
- &peer.stats.lastHandshakeNano,
- time.Now().UnixNano(),
- )
peer.signal.handshakeCompleted.Send()
peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
}
@@ -139,16 +139,86 @@ func (peer *Peer) TimerEphemeralKeyCreated() { peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
}
+/* Sends a new handshake initiation message to the peer (endpoint)
+ */
+func (peer *Peer) sendNewHandshake() error {
+
+ // temporarily disable the handshake complete signal
+
+ peer.signal.handshakeCompleted.Disable()
+
+ // create initiation message
+
+ msg, err := peer.device.CreateMessageInitiation(peer)
+ if err != nil {
+ return err
+ }
+
+ // marshal handshake message
+
+ var buff [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buff[:0])
+ binary.Write(writer, binary.LittleEndian, msg)
+ packet := writer.Bytes()
+ peer.mac.AddMacs(packet)
+
+ // send to endpoint
+
+ peer.TimerAnyAuthenticatedPacketTraversal()
+
+ err = peer.SendBuffer(packet)
+ if err == nil {
+ peer.signal.handshakeCompleted.Enable()
+ }
+
+ // set timeout
+
+ jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
+
+ peer.timer.keepalivePassive.Stop()
+ peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
+
+ return err
+}
+
func (peer *Peer) RoutineTimerHandler() {
+
+ defer peer.routines.stopping.Done()
+
device := peer.device
logInfo := device.log.Info
logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String())
+ // reset all timers
+
+ peer.timer.keepalivePassive.Stop()
+ peer.timer.handshakeDeadline.Stop()
+ peer.timer.handshakeTimeout.Stop()
+ peer.timer.handshakeNew.Stop()
+ peer.timer.zeroAllKeys.Stop()
+
+ interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
+ if interval > 0 {
+ duration := time.Duration(interval) * time.Second
+ peer.timer.keepalivePersistent.Reset(duration)
+ }
+
+ // signal synchronised setup complete
+
+ peer.routines.starting.Done()
+
+ // handle timer events
+
for {
select {
+ /* stopping */
+
+ case <-peer.routines.stop.Wait():
+ return
+
/* timers */
// keep-alive
@@ -158,6 +228,7 @@ func (peer *Peer) RoutineTimerHandler() { interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 {
logDebug.Println("Sending keep-alive to", peer.String())
+ peer.timer.keepalivePassive.Stop()
peer.SendKeepAlive()
}
@@ -168,8 +239,8 @@ func (peer *Peer) RoutineTimerHandler() { peer.SendKeepAlive()
if peer.timer.needAnotherKeepalive {
- peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
peer.timer.needAnotherKeepalive = false
+ peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
}
// clear key material timer
@@ -203,17 +274,12 @@ func (peer *Peer) RoutineTimerHandler() { // zero out handshake
device.indices.Delete(hs.localIndex)
-
- hs.localIndex = 0
- setZero(hs.localEphemeral[:])
- setZero(hs.remoteEphemeral[:])
- setZero(hs.chainKey[:])
- setZero(hs.hash[:])
+ hs.Clear()
hs.mutex.Unlock()
// handshake timers
- case <-peer.timer.newHandshake.Wait():
+ case <-peer.timer.handshakeNew.Wait():
logInfo.Println("Retrying handshake with", peer.String())
peer.signal.handshakeBegin.Send()
@@ -232,7 +298,7 @@ func (peer *Peer) RoutineTimerHandler() { err := peer.sendNewHandshake()
if err != nil {
logInfo.Println(
- "Failed to send handshake to peer:", peer.String())
+ "Failed to send handshake to peer:", peer.String(), "(", err, ")")
}
case <-peer.timer.handshakeDeadline.Wait():
@@ -248,9 +314,6 @@ func (peer *Peer) RoutineTimerHandler() { /* signals */
- case <-peer.signal.stop.Wait():
- return
-
case <-peer.signal.handshakeBegin.Wait():
peer.signal.handshakeBegin.Disable()
@@ -258,7 +321,7 @@ func (peer *Peer) RoutineTimerHandler() { err := peer.sendNewHandshake()
if err != nil {
logInfo.Println(
- "Failed to send handshake to peer:", peer.String())
+ "Failed to send handshake to peer:", peer.String(), "(", err, ")")
}
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
@@ -268,48 +331,16 @@ func (peer *Peer) RoutineTimerHandler() { logInfo.Println(
"Handshake completed for:", peer.String())
+ atomic.StoreInt64(
+ &peer.stats.lastHandshakeNano,
+ time.Now().UnixNano(),
+ )
+
peer.timer.handshakeTimeout.Stop()
peer.timer.handshakeDeadline.Stop()
peer.signal.handshakeBegin.Enable()
- }
- }
-}
-
-/* Sends a new handshake initiation message to the peer (endpoint)
- */
-func (peer *Peer) sendNewHandshake() error {
-
- // temporarily disable the handshake complete signal
-
- peer.signal.handshakeCompleted.Disable()
- // create initiation message
-
- msg, err := peer.device.CreateMessageInitiation(peer)
- if err != nil {
- return err
- }
-
- // marshal handshake message
-
- var buff [MessageInitiationSize]byte
- writer := bytes.NewBuffer(buff[:0])
- binary.Write(writer, binary.LittleEndian, msg)
- packet := writer.Bytes()
- peer.mac.AddMacs(packet)
-
- // send to endpoint
-
- err = peer.SendBuffer(packet)
- if err == nil {
- peer.TimerAnyAuthenticatedPacketTraversal()
- peer.signal.handshakeCompleted.Enable()
+ peer.timer.sendLastMinuteHandshake = false
+ }
}
-
- // set timeout
-
- jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
- peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
-
- return err
}
diff --git a/src/trie_rand_test.go b/trie_rand_test.go index 840d269..840d269 100644 --- a/src/trie_rand_test.go +++ b/trie_rand_test.go diff --git a/src/trie_test.go b/trie_test.go index 9d53df3..9d53df3 100644 --- a/src/trie_test.go +++ b/trie_test.go @@ -45,26 +45,14 @@ func (device *Device) RoutineTUNEventReader() { } } - if event&TUNEventUp != 0 { - if !device.tun.isUp.Get() { - // begin listening for incomming datagrams - logInfo.Println("Interface set up") - device.tun.isUp.Set(true) - if err := updateBind(device); err != nil { - logInfo.Println("Failed to bind UDP socket:", err) - } - } + if event&TUNEventUp != 0 && !device.isUp.Get() { + logInfo.Println("Interface set up") + device.Up() } - if event&TUNEventDown != 0 { - if device.tun.isUp.Get() { - // stop listening for incomming datagrams - logInfo.Println("Interface set down") - device.tun.isUp.Set(false) - if err := closeBind(device); err != nil { - logInfo.Println("Failed to close UDP socket:", err) - } - } + if event&TUNEventDown != 0 && device.isUp.Get() { + logInfo.Println("Interface set down") + device.Down() } } } diff --git a/src/tun_darwin.go b/tun_darwin.go index 87f6af6..87f6af6 100644 --- a/src/tun_darwin.go +++ b/tun_darwin.go diff --git a/src/tun_linux.go b/tun_linux.go index daa2462..daa2462 100644 --- a/src/tun_linux.go +++ b/tun_linux.go diff --git a/src/tun_windows.go b/tun_windows.go index 0711032..0711032 100644 --- a/src/tun_windows.go +++ b/tun_windows.go @@ -25,32 +25,51 @@ func (s *IPCError) ErrorCode() int64 { func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { - // create lines + device.log.Debug.Println("UAPI: Processing get operation") - device.mutex.RLock() - device.net.mutex.RLock() + // create lines lines := make([]string, 0, 100) send := func(line string) { lines = append(lines, line) } - if !device.privateKey.IsZero() { - send("private_key=" + device.privateKey.ToHex()) - } + func() { - if device.net.port != 0 { - send(fmt.Sprintf("listen_port=%d", device.net.port)) - } + // lock required resources - if device.net.fwmark != 0 { - send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) - } + device.net.mutex.RLock() + defer device.net.mutex.RUnlock() + + device.noise.mutex.RLock() + defer device.noise.mutex.RUnlock() + + device.routing.mutex.RLock() + defer device.routing.mutex.RUnlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + // serialize device related values + + if !device.noise.privateKey.IsZero() { + send("private_key=" + device.noise.privateKey.ToHex()) + } + + if device.net.port != 0 { + send(fmt.Sprintf("listen_port=%d", device.net.port)) + } + + if device.net.fwmark != 0 { + send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) + } - for _, peer := range device.peers { - func() { + // serialize each peer state + + for _, peer := range device.peers.keyMap { peer.mutex.RLock() defer peer.mutex.RUnlock() + send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex()) if peer.endpoint != nil { @@ -69,16 +88,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { atomic.LoadUint64(&peer.persistentKeepaliveInterval), )) - for _, ip := range device.routingTable.AllowedIPs(peer) { + for _, ip := range device.routing.table.AllowedIPs(peer) { send("allowed_ip=" + ip.String()) } - }() - } - device.net.mutex.RUnlock() - device.mutex.RUnlock() + } + }() - // send lines + // send lines (does not require resource locks) for _, line := range lines { _, err := socket.WriteString(line + "\n") @@ -94,7 +111,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { scanner := bufio.NewScanner(socket) - logInfo := device.log.Info logError := device.log.Error logDebug := device.log.Debug @@ -130,16 +146,28 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to set private_key:", err) return &IPCError{Code: ipcErrorInvalid} } + logDebug.Println("UAPI: Updating device private key") device.SetPrivateKey(sk) case "listen_port": + + // parse port number + port, err := strconv.ParseUint(value, 10, 16) if err != nil { logError.Println("Failed to parse listen_port:", err) return &IPCError{Code: ipcErrorInvalid} } + + // update port and rebind + + logDebug.Println("UAPI: Updating listen port") + + device.net.mutex.Lock() device.net.port = uint16(port) - if err := updateBind(device); err != nil { + device.net.mutex.Unlock() + + if err := device.BindUpdate(); err != nil { logError.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorPortInUse} } @@ -161,15 +189,20 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return &IPCError{Code: ipcErrorInvalid} } + logDebug.Println("UAPI: Updating fwmark") + device.net.mutex.Lock() device.net.fwmark = uint32(fwmark) - if err := device.net.bind.SetMark(fwmark); err != nil { + device.net.mutex.Unlock() + + if err := device.BindUpdate(); err != nil { logError.Println("Failed to update fwmark:", err) + return &IPCError{Code: ipcErrorPortInUse} } - device.net.mutex.Unlock() case "public_key": // switch to peer configuration + logDebug.Println("UAPI: Transition to peer configuration") deviceConfig = false case "replace_peers": @@ -177,6 +210,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to set replace_peers, invalid value:", value) return &IPCError{Code: ipcErrorInvalid} } + logDebug.Println("UAPI: Removing all peers") device.RemoveAllPeers() default: @@ -192,43 +226,41 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { switch key { case "public_key": - var pubKey NoisePublicKey - err := pubKey.FromHex(value) + var publicKey NoisePublicKey + err := publicKey.FromHex(value) if err != nil { logError.Println("Failed to get peer by public_key:", err) return &IPCError{Code: ipcErrorInvalid} } - // check if public key of peer equal to device - - device.mutex.RLock() - if device.publicKey.Equals(pubKey) { + // ignore peer with public key of device - // create dummy instance (not added to device) + device.noise.mutex.RLock() + equals := device.noise.publicKey.Equals(publicKey) + device.noise.mutex.RUnlock() + if equals { peer = &Peer{} dummy = true - device.mutex.RUnlock() - logInfo.Println("Ignoring peer with public key of device") + } - } else { + // find peer referenced - // find peer referenced + peer = device.LookupPeer(publicKey) - peer, _ = device.peers[pubKey] - device.mutex.RUnlock() - if peer == nil { - peer, err = device.NewPeer(pubKey) - if err != nil { - logError.Println("Failed to create new peer:", err) - return &IPCError{Code: ipcErrorInvalid} - } + if peer == nil { + peer, err = device.NewPeer(publicKey) + if err != nil { + logError.Println("Failed to create new peer:", err) + return &IPCError{Code: ipcErrorInvalid} } - peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) - dummy = false - + logDebug.Println("UAPI: Created new peer:", peer.String()) } + peer.mutex.Lock() + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) + peer.mutex.Unlock() + case "remove": // remove currently selected peer from device @@ -238,7 +270,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return &IPCError{Code: ipcErrorInvalid} } if !dummy { - logDebug.Println("Removing", peer.String()) + logDebug.Println("UAPI: Removing peer:", peer.String()) device.RemovePeer(peer.handshake.remoteStatic) } peer = &Peer{} @@ -248,9 +280,12 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // update PSK - peer.mutex.Lock() + logDebug.Println("UAPI: Updating pre-shared key for peer:", peer.String()) + + peer.handshake.mutex.Lock() err := peer.handshake.presharedKey.FromHex(value) - peer.mutex.Unlock() + peer.handshake.mutex.Unlock() + if err != nil { logError.Println("Failed to set preshared_key:", err) return &IPCError{Code: ipcErrorInvalid} @@ -260,6 +295,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // set endpoint destination + logDebug.Println("UAPI: Updating endpoint for peer:", peer.String()) + err := func() error { peer.mutex.Lock() defer peer.mutex.Unlock() @@ -281,6 +318,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // update keep-alive interval + logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer.String()) + secs, err := strconv.ParseUint(value, 10, 16) if err != nil { logError.Println("Failed to set persistent_keepalive_interval:", err) @@ -299,31 +338,47 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to get tun device status:", err) return &IPCError{Code: ipcErrorIO} } - if device.tun.isUp.Get() && !dummy { + if device.isUp.Get() && !dummy { peer.SendKeepAlive() } } case "replace_allowed_ips": + + logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer.String()) + if value != "true" { logError.Println("Failed to set replace_allowed_ips, invalid value:", value) return &IPCError{Code: ipcErrorInvalid} } - if !dummy { - device.routingTable.RemovePeer(peer) + + if dummy { + continue } + device.routing.mutex.Lock() + device.routing.table.RemovePeer(peer) + device.routing.mutex.Unlock() + case "allowed_ip": + + logDebug.Println("UAPI: Adding allowed_ip to peer:", peer.String()) + _, network, err := net.ParseCIDR(value) if err != nil { logError.Println("Failed to set allowed_ip:", err) return &IPCError{Code: ipcErrorInvalid} } - ones, _ := network.Mask.Size() - if !dummy { - device.routingTable.Insert(network.IP, uint(ones), peer) + + if dummy { + continue } + ones, _ := network.Mask.Size() + device.routing.mutex.Lock() + device.routing.table.Insert(network.IP, uint(ones), peer) + device.routing.mutex.Unlock() + default: logError.Println("Invalid UAPI key (peer configuration):", key) return &IPCError{Code: ipcErrorInvalid} diff --git a/src/uapi_darwin.go b/uapi_darwin.go index 63d4d8d..63d4d8d 100644 --- a/src/uapi_darwin.go +++ b/uapi_darwin.go diff --git a/src/uapi_linux.go b/uapi_linux.go index f97a18a..f97a18a 100644 --- a/src/uapi_linux.go +++ b/uapi_linux.go diff --git a/src/uapi_windows.go b/uapi_windows.go index a4599a5..a4599a5 100644 --- a/src/uapi_windows.go +++ b/uapi_windows.go diff --git a/src/xchacha20.go b/xchacha20.go index 5d963e0..5d963e0 100644 --- a/src/xchacha20.go +++ b/xchacha20.go diff --git a/src/xchacha20_test.go b/xchacha20_test.go index 0f41cf8..0f41cf8 100644 --- a/src/xchacha20_test.go +++ b/xchacha20_test.go |