aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--Makefile (renamed from src/Makefile)0
-rwxr-xr-xbuild.cmd (renamed from src/build.cmd)0
-rw-r--r--conn.go (renamed from src/conn.go)29
-rw-r--r--conn_default.go (renamed from src/conn_default.go)0
-rw-r--r--conn_linux.go (renamed from src/conn_linux.go)0
-rw-r--r--constants.go (renamed from src/constants.go)0
-rw-r--r--cookie.go (renamed from src/cookie.go)0
-rw-r--r--cookie_test.go (renamed from src/cookie_test.go)0
-rw-r--r--daemon_darwin.go (renamed from src/daemon_darwin.go)0
-rw-r--r--daemon_linux.go (renamed from src/daemon_linux.go)0
-rw-r--r--daemon_windows.go (renamed from src/daemon_windows.go)0
-rw-r--r--device.go364
-rw-r--r--helper_test.go (renamed from src/helper_test.go)8
-rw-r--r--index.go (renamed from src/index.go)0
-rw-r--r--ip.go (renamed from src/ip.go)0
-rw-r--r--kdf_test.go (renamed from src/kdf_test.go)0
-rw-r--r--keypair.go (renamed from src/keypair.go)4
-rw-r--r--logger.go (renamed from src/logger.go)0
-rw-r--r--main.go (renamed from src/main.go)0
-rw-r--r--misc.go (renamed from src/misc.go)0
-rw-r--r--noise_helpers.go (renamed from src/noise_helpers.go)7
-rw-r--r--noise_protocol.go (renamed from src/noise_protocol.go)41
-rw-r--r--noise_test.go (renamed from src/noise_test.go)4
-rw-r--r--noise_types.go (renamed from src/noise_types.go)0
-rw-r--r--peer.go (renamed from src/peer.go)226
-rw-r--r--ratelimiter.go (renamed from src/ratelimiter.go)0
-rw-r--r--ratelimiter_test.go (renamed from src/ratelimiter_test.go)0
-rw-r--r--receive.go (renamed from src/receive.go)43
-rw-r--r--replay.go (renamed from src/replay.go)0
-rw-r--r--replay_test.go (renamed from src/replay_test.go)0
-rw-r--r--routing.go (renamed from src/routing.go)0
-rw-r--r--send.go (renamed from src/send.go)27
-rw-r--r--signal.go (renamed from src/signal.go)0
-rw-r--r--src/device.go221
-rw-r--r--tai64.go (renamed from src/tai64.go)0
-rwxr-xr-xtests/netns.sh (renamed from src/tests/netns.sh)4
-rw-r--r--timer.go (renamed from src/timer.go)6
-rw-r--r--timers.go (renamed from src/timers.go)169
-rw-r--r--trie.go (renamed from src/trie.go)0
-rw-r--r--trie_rand_test.go (renamed from src/trie_rand_test.go)0
-rw-r--r--trie_test.go (renamed from src/trie_test.go)0
-rw-r--r--tun.go (renamed from src/tun.go)24
-rw-r--r--tun_darwin.go (renamed from src/tun_darwin.go)0
-rw-r--r--tun_linux.go (renamed from src/tun_linux.go)0
-rw-r--r--tun_windows.go (renamed from src/tun_windows.go)0
-rw-r--r--uapi.go (renamed from src/uapi.go)165
-rw-r--r--uapi_darwin.go (renamed from src/uapi_darwin.go)0
-rw-r--r--uapi_linux.go (renamed from src/uapi_linux.go)0
-rw-r--r--uapi_windows.go (renamed from src/uapi_windows.go)0
-rw-r--r--xchacha20.go (renamed from src/xchacha20.go)0
-rw-r--r--xchacha20_test.go (renamed from src/xchacha20_test.go)0
52 files changed, 864 insertions, 479 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e460293
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+wireguard-go
diff --git a/src/Makefile b/Makefile
index 5b23ecc..5b23ecc 100644
--- a/src/Makefile
+++ b/Makefile
diff --git a/src/build.cmd b/build.cmd
index 52cb883..52cb883 100755
--- a/src/build.cmd
+++ b/build.cmd
diff --git a/src/conn.go b/conn.go
index 6d292d3..fb30ec2 100644
--- a/src/conn.go
+++ b/conn.go
@@ -64,13 +64,13 @@ func unsafeCloseBind(device *Device) error {
return err
}
-func updateBind(device *Device) error {
- device.mutex.Lock()
- defer device.mutex.Unlock()
+func (device *Device) BindUpdate() error {
- netc := &device.net
- netc.mutex.Lock()
- defer netc.mutex.Unlock()
+ device.net.mutex.Lock()
+ defer device.net.mutex.Unlock()
+
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
// close existing sockets
@@ -78,17 +78,14 @@ func updateBind(device *Device) error {
return err
}
- // assumption: netc.update WaitGroup should be exactly 1
-
// open new sockets
- if device.tun.isUp.Get() {
-
- device.log.Debug.Println("UDP bind updating")
+ if device.isUp.Get() {
// bind to new port
var err error
+ netc := &device.net
netc.bind, netc.port, err = CreateBind(netc.port)
if err != nil {
netc.bind = nil
@@ -104,15 +101,15 @@ func updateBind(device *Device) error {
// clear cached source addresses
- for _, peer := range device.peers {
+ for _, peer := range device.peers.keyMap {
peer.mutex.Lock()
+ defer peer.mutex.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
- peer.mutex.Unlock()
}
- // decrease waitgroup to 0
+ // start receiving routines
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
@@ -123,11 +120,9 @@ func updateBind(device *Device) error {
return nil
}
-func closeBind(device *Device) error {
- device.mutex.Lock()
+func (device *Device) BindClose() error {
device.net.mutex.Lock()
err := unsafeCloseBind(device)
device.net.mutex.Unlock()
- device.mutex.Unlock()
return err
}
diff --git a/src/conn_default.go b/conn_default.go
index 5b73c90..5b73c90 100644
--- a/src/conn_default.go
+++ b/conn_default.go
diff --git a/src/conn_linux.go b/conn_linux.go
index cdba74f..cdba74f 100644
--- a/src/conn_linux.go
+++ b/conn_linux.go
diff --git a/src/constants.go b/constants.go
index 71dd98e..71dd98e 100644
--- a/src/constants.go
+++ b/constants.go
diff --git a/src/cookie.go b/cookie.go
index a13ad49..a13ad49 100644
--- a/src/cookie.go
+++ b/cookie.go
diff --git a/src/cookie_test.go b/cookie_test.go
index d745fe7..d745fe7 100644
--- a/src/cookie_test.go
+++ b/cookie_test.go
diff --git a/src/daemon_darwin.go b/daemon_darwin.go
index 913af0e..913af0e 100644
--- a/src/daemon_darwin.go
+++ b/daemon_darwin.go
diff --git a/src/daemon_linux.go b/daemon_linux.go
index e1aaede..e1aaede 100644
--- a/src/daemon_linux.go
+++ b/daemon_linux.go
diff --git a/src/daemon_windows.go b/daemon_windows.go
index d5ec1e8..d5ec1e8 100644
--- a/src/daemon_windows.go
+++ b/daemon_windows.go
diff --git a/device.go b/device.go
new file mode 100644
index 0000000..22e0990
--- /dev/null
+++ b/device.go
@@ -0,0 +1,364 @@
+package main
+
+import (
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+type Device struct {
+ isUp AtomicBool // device is (going) up
+ isClosed AtomicBool // device is closed? (acting as guard)
+ log *Logger
+
+ // synchronized resources (locks acquired in order)
+
+ state struct {
+ mutex sync.Mutex
+ changing AtomicBool
+ current bool
+ }
+
+ net struct {
+ mutex sync.RWMutex
+ bind Bind // bind interface
+ port uint16 // listening port
+ fwmark uint32 // mark value (0 = disabled)
+ }
+
+ noise struct {
+ mutex sync.RWMutex
+ privateKey NoisePrivateKey
+ publicKey NoisePublicKey
+ }
+
+ routing struct {
+ mutex sync.RWMutex
+ table RoutingTable
+ }
+
+ peers struct {
+ mutex sync.RWMutex
+ keyMap map[NoisePublicKey]*Peer
+ }
+
+ // unprotected / "self-synchronising resources"
+
+ indices IndexTable
+ mac CookieChecker
+
+ rate struct {
+ underLoadUntil atomic.Value
+ limiter Ratelimiter
+ }
+
+ pool struct {
+ messageBuffers sync.Pool
+ }
+
+ queue struct {
+ encryption chan *QueueOutboundElement
+ decryption chan *QueueInboundElement
+ handshake chan QueueHandshakeElement
+ }
+
+ signal struct {
+ stop Signal
+ }
+
+ tun struct {
+ device TUNDevice
+ mtu int32
+ }
+}
+
+/* Converts the peer into a "zombie", which remains in the peer map,
+ * but processes no packets and does not exists in the routing table.
+ *
+ * Must hold:
+ * device.peers.mutex : exclusive lock
+ * device.routing : exclusive lock
+ */
+func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
+
+ // stop routing and processing of packets
+
+ device.routing.table.RemovePeer(peer)
+ peer.Stop()
+
+ // remove from peer map
+
+ delete(device.peers.keyMap, key)
+}
+
+func deviceUpdateState(device *Device) {
+
+ // check if state already being updated (guard)
+
+ if device.state.changing.Swap(true) {
+ return
+ }
+
+ // compare to current state of device
+
+ device.state.mutex.Lock()
+
+ newIsUp := device.isUp.Get()
+
+ if newIsUp == device.state.current {
+ device.state.changing.Set(false)
+ device.state.mutex.Unlock()
+ return
+ }
+
+ // change state of device
+
+ switch newIsUp {
+ case true:
+ if err := device.BindUpdate(); err != nil {
+ device.isUp.Set(false)
+ break
+ }
+ device.peers.mutex.Lock()
+ for _, peer := range device.peers.keyMap {
+ peer.Start()
+ }
+ device.peers.mutex.Unlock()
+
+ case false:
+ device.BindClose()
+ device.peers.mutex.Lock()
+ for _, peer := range device.peers.keyMap {
+ peer.Stop()
+ }
+ device.peers.mutex.Unlock()
+ }
+
+ // update state variables
+
+ device.state.current = newIsUp
+ device.state.changing.Set(false)
+ device.state.mutex.Unlock()
+
+ // check for state change in the mean time
+
+ deviceUpdateState(device)
+}
+
+func (device *Device) Up() {
+
+ // closed device cannot be brought up
+
+ if device.isClosed.Get() {
+ return
+ }
+
+ device.state.mutex.Lock()
+ device.isUp.Set(true)
+ device.state.mutex.Unlock()
+ deviceUpdateState(device)
+}
+
+func (device *Device) Down() {
+ device.state.mutex.Lock()
+ device.isUp.Set(false)
+ device.state.mutex.Unlock()
+ deviceUpdateState(device)
+}
+
+func (device *Device) IsUnderLoad() bool {
+
+ // check if currently under load
+
+ now := time.Now()
+ underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
+ if underLoad {
+ device.rate.underLoadUntil.Store(now.Add(time.Second))
+ return true
+ }
+
+ // check if recently under load
+
+ until := device.rate.underLoadUntil.Load().(time.Time)
+ return until.After(now)
+}
+
+func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
+
+ // lock required resources
+
+ device.noise.mutex.Lock()
+ defer device.noise.mutex.Unlock()
+
+ device.routing.mutex.Lock()
+ defer device.routing.mutex.Unlock()
+
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
+
+ for _, peer := range device.peers.keyMap {
+ peer.handshake.mutex.RLock()
+ defer peer.handshake.mutex.RUnlock()
+ }
+
+ // remove peers with matching public keys
+
+ publicKey := sk.publicKey()
+ for key, peer := range device.peers.keyMap {
+ if peer.handshake.remoteStatic.Equals(publicKey) {
+ unsafeRemovePeer(device, peer, key)
+ }
+ }
+
+ // update key material
+
+ device.noise.privateKey = sk
+ device.noise.publicKey = publicKey
+ device.mac.Init(publicKey)
+
+ // do static-static DH pre-computations
+
+ rmKey := device.noise.privateKey.IsZero()
+
+ for key, peer := range device.peers.keyMap {
+
+ hs := &peer.handshake
+
+ if rmKey {
+ hs.precomputedStaticStatic = [NoisePublicKeySize]byte{}
+ } else {
+ hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic)
+ }
+
+ if isZero(hs.precomputedStaticStatic[:]) {
+ unsafeRemovePeer(device, peer, key)
+ }
+ }
+
+ return nil
+}
+
+func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
+ return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
+}
+
+func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
+ device.pool.messageBuffers.Put(msg)
+}
+
+func NewDevice(tun TUNDevice, logger *Logger) *Device {
+ device := new(Device)
+
+ device.isUp.Set(false)
+ device.isClosed.Set(false)
+
+ device.log = logger
+ device.tun.device = tun
+ device.peers.keyMap = make(map[NoisePublicKey]*Peer)
+
+ // initialize anti-DoS / anti-scanning features
+
+ device.rate.limiter.Init()
+ device.rate.underLoadUntil.Store(time.Time{})
+
+ // initialize noise & crypt-key routine
+
+ device.indices.Init()
+ device.routing.table.Reset()
+
+ // setup buffer pool
+
+ device.pool.messageBuffers = sync.Pool{
+ New: func() interface{} {
+ return new([MaxMessageSize]byte)
+ },
+ }
+
+ // create queues
+
+ device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
+ device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
+ device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
+
+ // prepare signals
+
+ device.signal.stop = NewSignal()
+
+ // prepare net
+
+ device.net.port = 0
+ device.net.bind = nil
+
+ // start workers
+
+ for i := 0; i < runtime.NumCPU(); i += 1 {
+ go device.RoutineEncryption()
+ go device.RoutineDecryption()
+ go device.RoutineHandshake()
+ }
+
+ go device.RoutineReadFromTUN()
+ go device.RoutineTUNEventReader()
+ go device.rate.limiter.RoutineGarbageCollector(device.signal.stop)
+
+ return device
+}
+
+func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
+ device.peers.mutex.RLock()
+ defer device.peers.mutex.RUnlock()
+
+ return device.peers.keyMap[pk]
+}
+
+func (device *Device) RemovePeer(key NoisePublicKey) {
+ device.noise.mutex.Lock()
+ defer device.noise.mutex.Unlock()
+
+ device.routing.mutex.Lock()
+ defer device.routing.mutex.Unlock()
+
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
+
+ // stop peer and remove from routing
+
+ peer, ok := device.peers.keyMap[key]
+ if ok {
+ unsafeRemovePeer(device, peer, key)
+ }
+}
+
+func (device *Device) RemoveAllPeers() {
+
+ device.routing.mutex.Lock()
+ defer device.routing.mutex.Unlock()
+
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
+
+ for key, peer := range device.peers.keyMap {
+ println("rm", peer.String())
+ unsafeRemovePeer(device, peer, key)
+ }
+
+ device.peers.keyMap = make(map[NoisePublicKey]*Peer)
+}
+
+func (device *Device) Close() {
+ device.log.Info.Println("Device closing")
+ if device.isClosed.Swap(true) {
+ return
+ }
+ device.signal.stop.Broadcast()
+ device.tun.device.Close()
+ device.BindClose()
+ device.isUp.Set(false)
+ device.RemoveAllPeers()
+ device.log.Info.Println("Interface closed")
+}
+
+func (device *Device) Wait() chan struct{} {
+ return device.signal.stop.Wait()
+}
diff --git a/src/helper_test.go b/helper_test.go
index 8548121..41e6b72 100644
--- a/src/helper_test.go
+++ b/helper_test.go
@@ -28,8 +28,8 @@ func (tun *DummyTUN) MTU() (int, error) {
return tun.mtu, nil
}
-func (tun *DummyTUN) Write(d []byte) (int, error) {
- tun.packets <- d
+func (tun *DummyTUN) Write(d []byte, offset int) (int, error) {
+ tun.packets <- d[offset:]
return len(d), nil
}
@@ -41,9 +41,9 @@ func (tun *DummyTUN) Events() chan TUNEvent {
return tun.events
}
-func (tun *DummyTUN) Read(d []byte) (int, error) {
+func (tun *DummyTUN) Read(d []byte, offset int) (int, error) {
t := <-tun.packets
- copy(d, t)
+ copy(d[offset:], t)
return len(t), nil
}
diff --git a/src/index.go b/index.go
index 1ba040e..1ba040e 100644
--- a/src/index.go
+++ b/index.go
diff --git a/src/ip.go b/ip.go
index 752a404..752a404 100644
--- a/src/ip.go
+++ b/ip.go
diff --git a/src/kdf_test.go b/kdf_test.go
index a89dacc..a89dacc 100644
--- a/src/kdf_test.go
+++ b/kdf_test.go
diff --git a/src/keypair.go b/keypair.go
index 7e5297b..283cb92 100644
--- a/src/keypair.go
+++ b/keypair.go
@@ -38,5 +38,7 @@ func (kp *KeyPairs) Current() *KeyPair {
}
func (device *Device) DeleteKeyPair(key *KeyPair) {
- device.indices.Delete(key.localIndex)
+ if key != nil {
+ device.indices.Delete(key.localIndex)
+ }
}
diff --git a/src/logger.go b/logger.go
index 0872ef9..0872ef9 100644
--- a/src/logger.go
+++ b/logger.go
diff --git a/src/main.go b/main.go
index b12bb09..b12bb09 100644
--- a/src/main.go
+++ b/main.go
diff --git a/src/misc.go b/misc.go
index 80e33f6..80e33f6 100644
--- a/src/misc.go
+++ b/misc.go
diff --git a/src/noise_helpers.go b/noise_helpers.go
index 24302c0..1e2de5f 100644
--- a/src/noise_helpers.go
+++ b/noise_helpers.go
@@ -3,6 +3,7 @@ package main
import (
"crypto/hmac"
"crypto/rand"
+ "crypto/subtle"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/curve25519"
"hash"
@@ -58,11 +59,11 @@ func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
}
func isZero(val []byte) bool {
- var acc byte
+ acc := 1
for _, b := range val {
- acc |= b
+ acc &= subtle.ConstantTimeByteEq(b, 0)
}
- return acc == 0
+ return acc == 1
}
func setZero(arr []byte) {
diff --git a/src/noise_protocol.go b/noise_protocol.go
index 2f9e1d5..c9713c0 100644
--- a/src/noise_protocol.go
+++ b/noise_protocol.go
@@ -121,6 +121,15 @@ func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
hsh.Reset()
}
+func (h *Handshake) Clear() {
+ setZero(h.localEphemeral[:])
+ setZero(h.remoteEphemeral[:])
+ setZero(h.chainKey[:])
+ setZero(h.hash[:])
+ h.localIndex = 0
+ h.state = HandshakeZeroed
+}
+
func (h *Handshake) mixHash(data []byte) {
mixHash(&h.hash, &h.hash, data)
}
@@ -137,6 +146,10 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
+
+ device.noise.mutex.RLock()
+ defer device.noise.mutex.RUnlock()
+
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
@@ -187,7 +200,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
ss[:],
)
aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
+ aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:])
}()
handshake.mixHash(msg.Static[:])
@@ -212,16 +225,19 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
}
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
- if msg.Type != MessageInitiationType {
- return nil
- }
-
var (
hash [blake2s.Size]byte
chainKey [blake2s.Size]byte
)
- mixHash(&hash, &InitialHash, device.publicKey[:])
+ if msg.Type != MessageInitiationType {
+ return nil
+ }
+
+ device.noise.mutex.RLock()
+ defer device.noise.mutex.RUnlock()
+
+ mixHash(&hash, &InitialHash, device.noise.publicKey[:])
mixHash(&hash, &hash, msg.Ephemeral[:])
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
@@ -231,7 +247,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
var peerPK NoisePublicKey
func() {
var key [chacha20poly1305.KeySize]byte
- ss := device.privateKey.sharedSecret(msg.Ephemeral)
+ ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
@@ -386,7 +402,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
ok := func() bool {
- // read lock handshake
+ // lock handshake state
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
@@ -395,6 +411,11 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
return false
}
+ // lock private key for reading
+
+ device.noise.mutex.RLock()
+ defer device.noise.mutex.RUnlock()
+
// finish 3-way DH
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
@@ -407,7 +428,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
}()
func() {
- ss := device.privateKey.sharedSecret(msg.Ephemeral)
+ ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
}()
@@ -425,7 +446,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
)
mixHash(&hash, &hash, tau[:])
- // authenticate
+ // authenticate transcript
aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
diff --git a/src/noise_test.go b/noise_test.go
index 0d7f0e9..5e9d44b 100644
--- a/src/noise_test.go
+++ b/noise_test.go
@@ -31,8 +31,8 @@ func TestNoiseHandshake(t *testing.T) {
defer dev1.Close()
defer dev2.Close()
- peer1, _ := dev2.NewPeer(dev1.privateKey.publicKey())
- peer2, _ := dev1.NewPeer(dev2.privateKey.publicKey())
+ peer1, _ := dev2.NewPeer(dev1.noise.privateKey.publicKey())
+ peer2, _ := dev1.NewPeer(dev2.noise.privateKey.publicKey())
assertEqual(
t,
diff --git a/src/noise_types.go b/noise_types.go
index 1a944df..1a944df 100644
--- a/src/noise_types.go
+++ b/noise_types.go
diff --git a/src/peer.go b/peer.go
index f582556..d8bb2bf 100644
--- a/src/peer.go
+++ b/peer.go
@@ -8,25 +8,32 @@ import (
"time"
)
+const (
+ PeerRoutineNumber = 4
+)
+
type Peer struct {
- id uint
+ isRunning AtomicBool
mutex sync.RWMutex
persistentKeepaliveInterval uint64
keyPairs KeyPairs
handshake Handshake
device *Device
endpoint Endpoint
- stats struct {
+
+ stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
lastHandshakeNano int64 // nano seconds since epoch
}
+
time struct {
mutex sync.RWMutex
lastSend time.Time // last send message
lastHandshake time.Time // last completed handshake
nextKeepalive time.Time
}
+
signal struct {
newKeyPair Signal // size 1, new key pair was generated
handshakeCompleted Signal // size 1, handshake completed
@@ -34,30 +41,62 @@ type Peer struct {
flushNonceQueue Signal // size 1, empty queued packets
messageSend Signal // size 1, message was send to peer
messageReceived Signal // size 1, authenticated message recv
- stop Signal // size 0, stop all goroutines
}
+
timer struct {
+
// state related to WireGuard timers
- keepalivePersistent Timer // set for persistent keepalives
- keepalivePassive Timer // set upon recieving messages
- newHandshake Timer // begin a new handshake (stale)
+ keepalivePersistent Timer // set for persistent keep-alive
+ keepalivePassive Timer // set upon receiving messages
zeroAllKeys Timer // zero all key material
+ handshakeNew Timer // begin a new handshake (stale)
handshakeDeadline Timer // complete handshake timeout
handshakeTimeout Timer // current handshake message timeout
sendLastMinuteHandshake bool
needAnotherKeepalive bool
}
+
queue struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work
inbound chan *QueueInboundElement // sequential ordering of work
}
+
+ routines struct {
+ mutex sync.Mutex // held when stopping / starting routines
+ starting sync.WaitGroup // routines pending start
+ stopping sync.WaitGroup // routines pending stop
+ stop Signal // size 0, stop all go-routines in peer
+ }
+
mac CookieGenerator
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
+
+ if device.isClosed.Get() {
+ return nil, errors.New("Device closed")
+ }
+
+ // lock resources
+
+ device.state.mutex.Lock()
+ defer device.state.mutex.Unlock()
+
+ device.noise.mutex.RLock()
+ defer device.noise.mutex.RUnlock()
+
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
+
+ // check if over limit
+
+ if len(device.peers.keyMap) >= MaxPeers {
+ return nil, errors.New("Too many peers")
+ }
+
// create peer
peer := new(Peer)
@@ -66,66 +105,46 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.mac.Init(pk)
peer.device = device
+ peer.isRunning.Set(false)
+ peer.timer.zeroAllKeys = NewTimer()
peer.timer.keepalivePersistent = NewTimer()
peer.timer.keepalivePassive = NewTimer()
- peer.timer.newHandshake = NewTimer()
- peer.timer.zeroAllKeys = NewTimer()
+ peer.timer.handshakeNew = NewTimer()
peer.timer.handshakeDeadline = NewTimer()
peer.timer.handshakeTimeout = NewTimer()
- // assign id for debugging
-
- device.mutex.Lock()
- peer.id = device.idCounter
- device.idCounter += 1
-
- // check if over limit
-
- if len(device.peers) >= MaxPeers {
- return nil, errors.New("Too many peers")
- }
-
// map public key
- _, ok := device.peers[pk]
+ _, ok := device.peers.keyMap[pk]
if ok {
return nil, errors.New("Adding existing peer")
}
- device.peers[pk] = peer
- device.mutex.Unlock()
+ device.peers.keyMap[pk] = peer
- // precompute DH
+ // pre-compute DH
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.remoteStatic = pk
- handshake.precomputedStaticStatic =
- device.privateKey.sharedSecret(handshake.remoteStatic)
+ handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk)
handshake.mutex.Unlock()
// reset endpoint
peer.endpoint = nil
- // prepare queuing
-
- peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
- peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
- peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
-
// prepare signaling & routines
- peer.signal.stop = NewSignal()
- peer.signal.newKeyPair = NewSignal()
- peer.signal.handshakeBegin = NewSignal()
- peer.signal.handshakeCompleted = NewSignal()
- peer.signal.flushNonceQueue = NewSignal()
+ peer.routines.mutex.Lock()
+ peer.routines.stop = NewSignal()
+ peer.routines.mutex.Unlock()
- go peer.RoutineNonce()
- go peer.RoutineTimerHandler()
- go peer.RoutineSequentialSender()
- go peer.RoutineSequentialReceiver()
+ // start peer
+
+ if peer.device.isUp.Get() {
+ peer.Start()
+ }
return peer, nil
}
@@ -133,32 +152,143 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.mutex.RLock()
defer peer.device.net.mutex.RUnlock()
+
+ if peer.device.net.bind == nil {
+ return errors.New("No bind")
+ }
+
peer.mutex.RLock()
defer peer.mutex.RUnlock()
+
if peer.endpoint == nil {
return errors.New("No known endpoint for peer")
}
+
return peer.device.net.bind.Send(buffer, peer.endpoint)
}
-/* Returns a short string identification for logging
+/* Returns a short string identifier for logging
*/
func (peer *Peer) String() string {
if peer.endpoint == nil {
return fmt.Sprintf(
- "peer(%d unknown %s)",
- peer.id,
+ "peer(unknown %s)",
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
}
return fmt.Sprintf(
- "peer(%d %s %s)",
- peer.id,
+ "peer(%s %s)",
peer.endpoint.DstToString(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
}
-func (peer *Peer) Close() {
- peer.signal.stop.Broadcast()
+func (peer *Peer) Start() {
+
+ // should never start a peer on a closed device
+
+ if peer.device.isClosed.Get() {
+ return
+ }
+
+ // prevent simultaneous start/stop operations
+
+ peer.routines.mutex.Lock()
+ defer peer.routines.mutex.Unlock()
+
+ if peer.isRunning.Get() {
+ return
+ }
+
+ peer.device.log.Debug.Println("Starting:", peer.String())
+
+ // sanity check : these should be 0
+
+ peer.routines.starting.Wait()
+ peer.routines.stopping.Wait()
+
+ // prepare queues and signals
+
+ peer.signal.newKeyPair = NewSignal()
+ peer.signal.handshakeBegin = NewSignal()
+ peer.signal.handshakeCompleted = NewSignal()
+ peer.signal.flushNonceQueue = NewSignal()
+
+ peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
+ peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
+ peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
+
+ peer.routines.stop = NewSignal()
+ peer.isRunning.Set(true)
+
+ // wait for routines to start
+
+ peer.routines.starting.Add(PeerRoutineNumber)
+ peer.routines.stopping.Add(PeerRoutineNumber)
+
+ go peer.RoutineNonce()
+ go peer.RoutineTimerHandler()
+ go peer.RoutineSequentialSender()
+ go peer.RoutineSequentialReceiver()
+
+ peer.routines.starting.Wait()
+ peer.isRunning.Set(true)
+}
+
+func (peer *Peer) Stop() {
+
+ // prevent simultaneous start/stop operations
+
+ peer.routines.mutex.Lock()
+ defer peer.routines.mutex.Unlock()
+
+ if !peer.isRunning.Swap(false) {
+ return
+ }
+
+ device := peer.device
+ device.log.Debug.Println("Stopping:", peer.String())
+
+ // stop & wait for ongoing peer routines
+
+ peer.routines.stop.Broadcast()
+ peer.routines.starting.Wait()
+ peer.routines.stopping.Wait()
+
+ // stop timers
+
+ peer.timer.keepalivePersistent.Stop()
+ peer.timer.keepalivePassive.Stop()
+ peer.timer.zeroAllKeys.Stop()
+ peer.timer.handshakeNew.Stop()
+ peer.timer.handshakeDeadline.Stop()
+ peer.timer.handshakeTimeout.Stop()
+
+ // close queues
+
+ close(peer.queue.nonce)
+ close(peer.queue.outbound)
+ close(peer.queue.inbound)
+
+ // clear key pairs
+
+ kp := &peer.keyPairs
+ kp.mutex.Lock()
+
+ device.DeleteKeyPair(kp.previous)
+ device.DeleteKeyPair(kp.current)
+ device.DeleteKeyPair(kp.next)
+
+ kp.previous = nil
+ kp.current = nil
+ kp.next = nil
+ kp.mutex.Unlock()
+
+ // clear handshake state
+
+ hs := &peer.handshake
+ hs.mutex.Lock()
+ device.indices.Delete(hs.localIndex)
+ hs.Clear()
+ hs.mutex.Unlock()
}
diff --git a/src/ratelimiter.go b/ratelimiter.go
index 6e5f005..6e5f005 100644
--- a/src/ratelimiter.go
+++ b/ratelimiter.go
diff --git a/src/ratelimiter_test.go b/ratelimiter_test.go
index 13b6a23..13b6a23 100644
--- a/src/ratelimiter_test.go
+++ b/ratelimiter_test.go
diff --git a/src/receive.go b/receive.go
index dbd2813..1f44df2 100644
--- a/src/receive.go
+++ b/receive.go
@@ -123,7 +123,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
case ipv6.Version:
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default:
- return
+ panic("invalid IP version")
}
if err != nil {
@@ -184,9 +184,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
// add to decryption queues
- device.addToDecryptionQueue(device.queue.decryption, elem)
- device.addToInboundQueue(peer.queue.inbound, elem)
- buffer = device.GetMessageBuffer()
+ if peer.isRunning.Get() {
+ device.addToDecryptionQueue(device.queue.decryption, elem)
+ device.addToInboundQueue(peer.queue.inbound, elem)
+ buffer = device.GetMessageBuffer()
+ }
continue
@@ -308,13 +310,20 @@ func (device *Device) RoutineHandshake() {
return
}
- // lookup peer and consume response
+ // lookup peer from index
entry := device.indices.Lookup(reply.Receiver)
+
if entry.peer == nil {
- return
+ continue
+ }
+
+ // consume reply
+
+ if peer := entry.peer; peer.isRunning.Get() {
+ peer.mac.ConsumeReply(&reply)
}
- entry.peer.mac.ConsumeReply(&reply)
+
continue
case MessageInitiationType, MessageResponseType:
@@ -323,7 +332,7 @@ func (device *Device) RoutineHandshake() {
if !device.mac.CheckMAC1(elem.packet) {
logDebug.Println("Received packet with invalid mac1")
- return
+ continue
}
// endpoints destination address is the source of the datagram
@@ -347,7 +356,7 @@ func (device *Device) RoutineHandshake() {
reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
if err != nil {
logError.Println("Failed to create cookie reply:", err)
- return
+ continue
}
// marshal and send reply
@@ -363,7 +372,7 @@ func (device *Device) RoutineHandshake() {
// check ratelimiter
- if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
+ if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
continue
}
}
@@ -486,19 +495,23 @@ func (device *Device) RoutineHandshake() {
func (peer *Peer) RoutineSequentialReceiver() {
+ defer peer.routines.stopping.Done()
+
device := peer.device
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
- logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
+ logDebug.Println("Routine, sequential receiver, started for peer", peer.String())
+
+ peer.routines.starting.Done()
for {
select {
- case <-peer.signal.stop.Wait():
- logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id)
+ case <-peer.routines.stop.Wait():
+ logDebug.Println("Routine, sequential receiver, stopped for peer", peer.String())
return
case elem := <-peer.queue.inbound:
@@ -572,7 +585,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// verify IPv4 source
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.routingTable.LookupIPv4(src) != peer {
+ if device.routing.table.LookupIPv4(src) != peer {
logInfo.Println(
"IPv4 packet with disallowed source address from",
peer.String(),
@@ -600,7 +613,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.routingTable.LookupIPv6(src) != peer {
+ if device.routing.table.LookupIPv6(src) != peer {
logInfo.Println(
"IPv6 packet with disallowed source address from",
peer.String(),
diff --git a/src/replay.go b/replay.go
index 5d42860..5d42860 100644
--- a/src/replay.go
+++ b/replay.go
diff --git a/src/replay_test.go b/replay_test.go
index 228fce6..228fce6 100644
--- a/src/replay_test.go
+++ b/replay_test.go
diff --git a/src/routing.go b/routing.go
index 2a2e237..2a2e237 100644
--- a/src/routing.go
+++ b/routing.go
diff --git a/src/send.go b/send.go
index 9537f5e..7488d3a 100644
--- a/src/send.go
+++ b/send.go
@@ -151,14 +151,14 @@ func (device *Device) RoutineReadFromTUN() {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
- peer = device.routingTable.LookupIPv4(dst)
+ peer = device.routing.table.LookupIPv4(dst)
case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
- peer = device.routingTable.LookupIPv6(dst)
+ peer = device.routing.table.LookupIPv6(dst)
default:
logDebug.Println("Received packet with unknown IP version")
@@ -170,9 +170,11 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue
- peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
- addToOutboundQueue(peer.queue.nonce, elem)
- elem = device.NewOutboundElement()
+ if peer.isRunning.Get() {
+ peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
+ addToOutboundQueue(peer.queue.nonce, elem)
+ elem = device.NewOutboundElement()
+ }
}
}
@@ -185,14 +187,18 @@ func (device *Device) RoutineReadFromTUN() {
func (peer *Peer) RoutineNonce() {
var keyPair *KeyPair
+ defer peer.routines.stopping.Done()
+
device := peer.device
logDebug := device.log.Debug
logDebug.Println("Routine, nonce worker, started for peer", peer.String())
+ peer.routines.starting.Done()
+
for {
NextPacket:
select {
- case <-peer.signal.stop.Wait():
+ case <-peer.routines.stop.Wait():
return
case elem := <-peer.queue.nonce:
@@ -217,7 +223,7 @@ func (peer *Peer) RoutineNonce() {
logDebug.Println("Clearing queue for", peer.String())
peer.FlushNonceQueue()
goto NextPacket
- case <-peer.signal.stop.Wait():
+ case <-peer.routines.stop.Wait():
return
}
}
@@ -309,15 +315,20 @@ func (device *Device) RoutineEncryption() {
* The routine terminates then the outbound queue is closed.
*/
func (peer *Peer) RoutineSequentialSender() {
+
+ defer peer.routines.stopping.Done()
+
device := peer.device
logDebug := device.log.Debug
logDebug.Println("Routine, sequential sender, started for", peer.String())
+ peer.routines.starting.Done()
+
for {
select {
- case <-peer.signal.stop.Wait():
+ case <-peer.routines.stop.Wait():
logDebug.Println(
"Routine, sequential sender, stopped for", peer.String())
return
diff --git a/src/signal.go b/signal.go
index 2cefad4..2cefad4 100644
--- a/src/signal.go
+++ b/signal.go
diff --git a/src/device.go b/src/device.go
deleted file mode 100644
index a3461ad..0000000
--- a/src/device.go
+++ /dev/null
@@ -1,221 +0,0 @@
-package main
-
-import (
- "runtime"
- "sync"
- "sync/atomic"
- "time"
-)
-
-type Device struct {
- closed AtomicBool // device is closed? (acting as guard)
- log *Logger // collection of loggers for levels
- idCounter uint // for assigning debug ids to peers
- fwMark uint32
- tun struct {
- device TUNDevice
- isUp AtomicBool
- mtu int32
- }
- pool struct {
- messageBuffers sync.Pool
- }
- net struct {
- mutex sync.RWMutex
- bind Bind // bind interface
- port uint16 // listening port
- fwmark uint32 // mark value (0 = disabled)
- }
- mutex sync.RWMutex
- privateKey NoisePrivateKey
- publicKey NoisePublicKey
- routingTable RoutingTable
- indices IndexTable
- queue struct {
- encryption chan *QueueOutboundElement
- decryption chan *QueueInboundElement
- handshake chan QueueHandshakeElement
- }
- signal struct {
- stop Signal
- }
- underLoadUntil atomic.Value
- ratelimiter Ratelimiter
- peers map[NoisePublicKey]*Peer
- mac CookieChecker
-}
-
-/* Warning:
- * The caller must hold the device mutex (write lock)
- */
-func removePeerUnsafe(device *Device, key NoisePublicKey) {
- peer, ok := device.peers[key]
- if !ok {
- return
- }
- peer.mutex.Lock()
- device.routingTable.RemovePeer(peer)
- delete(device.peers, key)
- peer.Close()
-}
-
-func (device *Device) IsUnderLoad() bool {
-
- // check if currently under load
-
- now := time.Now()
- underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
- if underLoad {
- device.underLoadUntil.Store(now.Add(time.Second))
- return true
- }
-
- // check if recently under load
-
- until := device.underLoadUntil.Load().(time.Time)
- return until.After(now)
-}
-
-func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
- device.mutex.Lock()
- defer device.mutex.Unlock()
-
- // remove peers with matching public keys
-
- publicKey := sk.publicKey()
- for key, peer := range device.peers {
- h := &peer.handshake
- h.mutex.RLock()
- if h.remoteStatic.Equals(publicKey) {
- removePeerUnsafe(device, key)
- }
- h.mutex.RUnlock()
- }
-
- // update key material
-
- device.privateKey = sk
- device.publicKey = publicKey
- device.mac.Init(publicKey)
-
- // do DH precomputations
-
- rmKey := device.privateKey.IsZero()
-
- for key, peer := range device.peers {
- h := &peer.handshake
- h.mutex.Lock()
- if rmKey {
- h.precomputedStaticStatic = [NoisePublicKeySize]byte{}
- } else {
- h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
- if isZero(h.precomputedStaticStatic[:]) {
- removePeerUnsafe(device, key)
- }
- }
- h.mutex.Unlock()
- }
-
- return nil
-}
-
-func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
- return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
-}
-
-func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
- device.pool.messageBuffers.Put(msg)
-}
-
-func NewDevice(tun TUNDevice, logger *Logger) *Device {
- device := new(Device)
- device.mutex.Lock()
- defer device.mutex.Unlock()
-
- device.log = logger
- device.peers = make(map[NoisePublicKey]*Peer)
- device.tun.device = tun
- device.tun.isUp.Set(false)
-
- device.indices.Init()
- device.ratelimiter.Init()
-
- device.routingTable.Reset()
- device.underLoadUntil.Store(time.Time{})
-
- // setup buffer pool
-
- device.pool.messageBuffers = sync.Pool{
- New: func() interface{} {
- return new([MaxMessageSize]byte)
- },
- }
-
- // create queues
-
- device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
- device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
- device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
-
- // prepare signals
-
- device.signal.stop = NewSignal()
-
- // prepare net
-
- device.net.port = 0
- device.net.bind = nil
-
- // start workers
-
- for i := 0; i < runtime.NumCPU(); i += 1 {
- go device.RoutineEncryption()
- go device.RoutineDecryption()
- go device.RoutineHandshake()
- }
-
- go device.RoutineReadFromTUN()
- go device.RoutineTUNEventReader()
- go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
-
- return device
-}
-
-func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
- device.mutex.RLock()
- defer device.mutex.RUnlock()
- return device.peers[pk]
-}
-
-func (device *Device) RemovePeer(key NoisePublicKey) {
- device.mutex.Lock()
- defer device.mutex.Unlock()
- removePeerUnsafe(device, key)
-}
-
-func (device *Device) RemoveAllPeers() {
- device.mutex.Lock()
- defer device.mutex.Unlock()
-
- for key, peer := range device.peers {
- peer.mutex.Lock()
- delete(device.peers, key)
- peer.Close()
- peer.mutex.Unlock()
- }
-}
-
-func (device *Device) Close() {
- if device.closed.Swap(true) {
- return
- }
- device.log.Info.Println("Closing device")
- device.RemoveAllPeers()
- device.signal.stop.Broadcast()
- device.tun.device.Close()
- closeBind(device)
-}
-
-func (device *Device) Wait() chan struct{} {
- return device.signal.stop.Wait()
-}
diff --git a/src/tai64.go b/tai64.go
index 2299a37..2299a37 100644
--- a/src/tai64.go
+++ b/tai64.go
diff --git a/src/tests/netns.sh b/tests/netns.sh
index 02d428b..6c47a44 100755
--- a/src/tests/netns.sh
+++ b/tests/netns.sh
@@ -80,11 +80,11 @@ pp ip netns add $netns2
ip0 link set up dev lo
# ip0 link add dev wg1 type wireguard
-n0 $program wg1
+n0 $program -f wg1 &
ip0 link set wg1 netns $netns1
# ip0 link add dev wg1 type wireguard
-n0 $program wg2
+n0 $program -f wg2 &
ip0 link set wg2 netns $netns2
key1="$(pp wg genkey)"
diff --git a/src/timer.go b/timer.go
index 3def253..f00ca49 100644
--- a/src/timer.go
+++ b/timer.go
@@ -43,12 +43,6 @@ func (t *Timer) Reset(dur time.Duration) {
t.Start(dur)
}
-func (t *Timer) Push(dur time.Duration) {
- if t.pending.Get() {
- t.Reset(dur)
- }
-}
-
func (t *Timer) Wait() <-chan time.Time {
return t.timer.C
}
diff --git a/src/timers.go b/timers.go
index ee47393..7092688 100644
--- a/src/timers.go
+++ b/timers.go
@@ -8,6 +8,12 @@ import (
"time"
)
+/* NOTE:
+ * Notion of validity
+ *
+ *
+ */
+
/* Called when a new authenticated message has been send
*
*/
@@ -44,25 +50,25 @@ func (peer *Peer) KeepKeyFreshReceiving() {
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
if send {
// do a last minute attempt at initiating a new handshake
- peer.signal.handshakeBegin.Send()
peer.timer.sendLastMinuteHandshake = true
+ peer.signal.handshakeBegin.Send()
}
}
/* Queues a keep-alive if no packets are queued for peer
*/
func (peer *Peer) SendKeepAlive() bool {
+ if len(peer.queue.nonce) != 0 {
+ return false
+ }
elem := peer.device.NewOutboundElement()
elem.packet = nil
- if len(peer.queue.nonce) == 0 {
- select {
- case peer.queue.nonce <- elem:
- return true
- default:
- return false
- }
+ select {
+ case peer.queue.nonce <- elem:
+ return true
+ default:
+ return false
}
- return true
}
/* Event:
@@ -70,9 +76,7 @@ func (peer *Peer) SendKeepAlive() bool {
*/
func (peer *Peer) TimerDataSent() {
peer.timer.keepalivePassive.Stop()
- if peer.timer.newHandshake.Pending() {
- peer.timer.newHandshake.Reset(NewHandshakeTime)
- }
+ peer.timer.handshakeNew.Start(NewHandshakeTime)
}
/* Event:
@@ -91,7 +95,7 @@ func (peer *Peer) TimerDataReceived() {
* Any (authenticated) packet received
*/
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
- peer.timer.newHandshake.Stop()
+ peer.timer.handshakeNew.Stop()
}
/* Event:
@@ -115,10 +119,6 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
* - First transport message under the "next" key
*/
func (peer *Peer) TimerHandshakeComplete() {
- atomic.StoreInt64(
- &peer.stats.lastHandshakeNano,
- time.Now().UnixNano(),
- )
peer.signal.handshakeCompleted.Send()
peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
}
@@ -139,16 +139,86 @@ func (peer *Peer) TimerEphemeralKeyCreated() {
peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
}
+/* Sends a new handshake initiation message to the peer (endpoint)
+ */
+func (peer *Peer) sendNewHandshake() error {
+
+ // temporarily disable the handshake complete signal
+
+ peer.signal.handshakeCompleted.Disable()
+
+ // create initiation message
+
+ msg, err := peer.device.CreateMessageInitiation(peer)
+ if err != nil {
+ return err
+ }
+
+ // marshal handshake message
+
+ var buff [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buff[:0])
+ binary.Write(writer, binary.LittleEndian, msg)
+ packet := writer.Bytes()
+ peer.mac.AddMacs(packet)
+
+ // send to endpoint
+
+ peer.TimerAnyAuthenticatedPacketTraversal()
+
+ err = peer.SendBuffer(packet)
+ if err == nil {
+ peer.signal.handshakeCompleted.Enable()
+ }
+
+ // set timeout
+
+ jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
+
+ peer.timer.keepalivePassive.Stop()
+ peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
+
+ return err
+}
+
func (peer *Peer) RoutineTimerHandler() {
+
+ defer peer.routines.stopping.Done()
+
device := peer.device
logInfo := device.log.Info
logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String())
+ // reset all timers
+
+ peer.timer.keepalivePassive.Stop()
+ peer.timer.handshakeDeadline.Stop()
+ peer.timer.handshakeTimeout.Stop()
+ peer.timer.handshakeNew.Stop()
+ peer.timer.zeroAllKeys.Stop()
+
+ interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
+ if interval > 0 {
+ duration := time.Duration(interval) * time.Second
+ peer.timer.keepalivePersistent.Reset(duration)
+ }
+
+ // signal synchronised setup complete
+
+ peer.routines.starting.Done()
+
+ // handle timer events
+
for {
select {
+ /* stopping */
+
+ case <-peer.routines.stop.Wait():
+ return
+
/* timers */
// keep-alive
@@ -158,6 +228,7 @@ func (peer *Peer) RoutineTimerHandler() {
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 {
logDebug.Println("Sending keep-alive to", peer.String())
+ peer.timer.keepalivePassive.Stop()
peer.SendKeepAlive()
}
@@ -168,8 +239,8 @@ func (peer *Peer) RoutineTimerHandler() {
peer.SendKeepAlive()
if peer.timer.needAnotherKeepalive {
- peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
peer.timer.needAnotherKeepalive = false
+ peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
}
// clear key material timer
@@ -203,17 +274,12 @@ func (peer *Peer) RoutineTimerHandler() {
// zero out handshake
device.indices.Delete(hs.localIndex)
-
- hs.localIndex = 0
- setZero(hs.localEphemeral[:])
- setZero(hs.remoteEphemeral[:])
- setZero(hs.chainKey[:])
- setZero(hs.hash[:])
+ hs.Clear()
hs.mutex.Unlock()
// handshake timers
- case <-peer.timer.newHandshake.Wait():
+ case <-peer.timer.handshakeNew.Wait():
logInfo.Println("Retrying handshake with", peer.String())
peer.signal.handshakeBegin.Send()
@@ -232,7 +298,7 @@ func (peer *Peer) RoutineTimerHandler() {
err := peer.sendNewHandshake()
if err != nil {
logInfo.Println(
- "Failed to send handshake to peer:", peer.String())
+ "Failed to send handshake to peer:", peer.String(), "(", err, ")")
}
case <-peer.timer.handshakeDeadline.Wait():
@@ -248,9 +314,6 @@ func (peer *Peer) RoutineTimerHandler() {
/* signals */
- case <-peer.signal.stop.Wait():
- return
-
case <-peer.signal.handshakeBegin.Wait():
peer.signal.handshakeBegin.Disable()
@@ -258,7 +321,7 @@ func (peer *Peer) RoutineTimerHandler() {
err := peer.sendNewHandshake()
if err != nil {
logInfo.Println(
- "Failed to send handshake to peer:", peer.String())
+ "Failed to send handshake to peer:", peer.String(), "(", err, ")")
}
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
@@ -268,48 +331,16 @@ func (peer *Peer) RoutineTimerHandler() {
logInfo.Println(
"Handshake completed for:", peer.String())
+ atomic.StoreInt64(
+ &peer.stats.lastHandshakeNano,
+ time.Now().UnixNano(),
+ )
+
peer.timer.handshakeTimeout.Stop()
peer.timer.handshakeDeadline.Stop()
peer.signal.handshakeBegin.Enable()
- }
- }
-}
-
-/* Sends a new handshake initiation message to the peer (endpoint)
- */
-func (peer *Peer) sendNewHandshake() error {
-
- // temporarily disable the handshake complete signal
-
- peer.signal.handshakeCompleted.Disable()
- // create initiation message
-
- msg, err := peer.device.CreateMessageInitiation(peer)
- if err != nil {
- return err
- }
-
- // marshal handshake message
-
- var buff [MessageInitiationSize]byte
- writer := bytes.NewBuffer(buff[:0])
- binary.Write(writer, binary.LittleEndian, msg)
- packet := writer.Bytes()
- peer.mac.AddMacs(packet)
-
- // send to endpoint
-
- err = peer.SendBuffer(packet)
- if err == nil {
- peer.TimerAnyAuthenticatedPacketTraversal()
- peer.signal.handshakeCompleted.Enable()
+ peer.timer.sendLastMinuteHandshake = false
+ }
}
-
- // set timeout
-
- jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
- peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
-
- return err
}
diff --git a/src/trie.go b/trie.go
index 405ffc3..405ffc3 100644
--- a/src/trie.go
+++ b/trie.go
diff --git a/src/trie_rand_test.go b/trie_rand_test.go
index 840d269..840d269 100644
--- a/src/trie_rand_test.go
+++ b/trie_rand_test.go
diff --git a/src/trie_test.go b/trie_test.go
index 9d53df3..9d53df3 100644
--- a/src/trie_test.go
+++ b/trie_test.go
diff --git a/src/tun.go b/tun.go
index 394ba9a..6259f33 100644
--- a/src/tun.go
+++ b/tun.go
@@ -45,26 +45,14 @@ func (device *Device) RoutineTUNEventReader() {
}
}
- if event&TUNEventUp != 0 {
- if !device.tun.isUp.Get() {
- // begin listening for incomming datagrams
- logInfo.Println("Interface set up")
- device.tun.isUp.Set(true)
- if err := updateBind(device); err != nil {
- logInfo.Println("Failed to bind UDP socket:", err)
- }
- }
+ if event&TUNEventUp != 0 && !device.isUp.Get() {
+ logInfo.Println("Interface set up")
+ device.Up()
}
- if event&TUNEventDown != 0 {
- if device.tun.isUp.Get() {
- // stop listening for incomming datagrams
- logInfo.Println("Interface set down")
- device.tun.isUp.Set(false)
- if err := closeBind(device); err != nil {
- logInfo.Println("Failed to close UDP socket:", err)
- }
- }
+ if event&TUNEventDown != 0 && device.isUp.Get() {
+ logInfo.Println("Interface set down")
+ device.Down()
}
}
}
diff --git a/src/tun_darwin.go b/tun_darwin.go
index 87f6af6..87f6af6 100644
--- a/src/tun_darwin.go
+++ b/tun_darwin.go
diff --git a/src/tun_linux.go b/tun_linux.go
index daa2462..daa2462 100644
--- a/src/tun_linux.go
+++ b/tun_linux.go
diff --git a/src/tun_windows.go b/tun_windows.go
index 0711032..0711032 100644
--- a/src/tun_windows.go
+++ b/tun_windows.go
diff --git a/src/uapi.go b/uapi.go
index 673d413..caaa498 100644
--- a/src/uapi.go
+++ b/uapi.go
@@ -25,32 +25,51 @@ func (s *IPCError) ErrorCode() int64 {
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
- // create lines
+ device.log.Debug.Println("UAPI: Processing get operation")
- device.mutex.RLock()
- device.net.mutex.RLock()
+ // create lines
lines := make([]string, 0, 100)
send := func(line string) {
lines = append(lines, line)
}
- if !device.privateKey.IsZero() {
- send("private_key=" + device.privateKey.ToHex())
- }
+ func() {
- if device.net.port != 0 {
- send(fmt.Sprintf("listen_port=%d", device.net.port))
- }
+ // lock required resources
- if device.net.fwmark != 0 {
- send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
- }
+ device.net.mutex.RLock()
+ defer device.net.mutex.RUnlock()
+
+ device.noise.mutex.RLock()
+ defer device.noise.mutex.RUnlock()
+
+ device.routing.mutex.RLock()
+ defer device.routing.mutex.RUnlock()
+
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
+
+ // serialize device related values
+
+ if !device.noise.privateKey.IsZero() {
+ send("private_key=" + device.noise.privateKey.ToHex())
+ }
+
+ if device.net.port != 0 {
+ send(fmt.Sprintf("listen_port=%d", device.net.port))
+ }
+
+ if device.net.fwmark != 0 {
+ send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
+ }
- for _, peer := range device.peers {
- func() {
+ // serialize each peer state
+
+ for _, peer := range device.peers.keyMap {
peer.mutex.RLock()
defer peer.mutex.RUnlock()
+
send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
if peer.endpoint != nil {
@@ -69,16 +88,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
atomic.LoadUint64(&peer.persistentKeepaliveInterval),
))
- for _, ip := range device.routingTable.AllowedIPs(peer) {
+ for _, ip := range device.routing.table.AllowedIPs(peer) {
send("allowed_ip=" + ip.String())
}
- }()
- }
- device.net.mutex.RUnlock()
- device.mutex.RUnlock()
+ }
+ }()
- // send lines
+ // send lines (does not require resource locks)
for _, line := range lines {
_, err := socket.WriteString(line + "\n")
@@ -94,7 +111,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
scanner := bufio.NewScanner(socket)
- logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
@@ -130,16 +146,28 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalid}
}
+ logDebug.Println("UAPI: Updating device private key")
device.SetPrivateKey(sk)
case "listen_port":
+
+ // parse port number
+
port, err := strconv.ParseUint(value, 10, 16)
if err != nil {
logError.Println("Failed to parse listen_port:", err)
return &IPCError{Code: ipcErrorInvalid}
}
+
+ // update port and rebind
+
+ logDebug.Println("UAPI: Updating listen port")
+
+ device.net.mutex.Lock()
device.net.port = uint16(port)
- if err := updateBind(device); err != nil {
+ device.net.mutex.Unlock()
+
+ if err := device.BindUpdate(); err != nil {
logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorPortInUse}
}
@@ -161,15 +189,20 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid}
}
+ logDebug.Println("UAPI: Updating fwmark")
+
device.net.mutex.Lock()
device.net.fwmark = uint32(fwmark)
- if err := device.net.bind.SetMark(fwmark); err != nil {
+ device.net.mutex.Unlock()
+
+ if err := device.BindUpdate(); err != nil {
logError.Println("Failed to update fwmark:", err)
+ return &IPCError{Code: ipcErrorPortInUse}
}
- device.net.mutex.Unlock()
case "public_key":
// switch to peer configuration
+ logDebug.Println("UAPI: Transition to peer configuration")
deviceConfig = false
case "replace_peers":
@@ -177,6 +210,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
+ logDebug.Println("UAPI: Removing all peers")
device.RemoveAllPeers()
default:
@@ -192,43 +226,41 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
switch key {
case "public_key":
- var pubKey NoisePublicKey
- err := pubKey.FromHex(value)
+ var publicKey NoisePublicKey
+ err := publicKey.FromHex(value)
if err != nil {
logError.Println("Failed to get peer by public_key:", err)
return &IPCError{Code: ipcErrorInvalid}
}
- // check if public key of peer equal to device
-
- device.mutex.RLock()
- if device.publicKey.Equals(pubKey) {
+ // ignore peer with public key of device
- // create dummy instance (not added to device)
+ device.noise.mutex.RLock()
+ equals := device.noise.publicKey.Equals(publicKey)
+ device.noise.mutex.RUnlock()
+ if equals {
peer = &Peer{}
dummy = true
- device.mutex.RUnlock()
- logInfo.Println("Ignoring peer with public key of device")
+ }
- } else {
+ // find peer referenced
- // find peer referenced
+ peer = device.LookupPeer(publicKey)
- peer, _ = device.peers[pubKey]
- device.mutex.RUnlock()
- if peer == nil {
- peer, err = device.NewPeer(pubKey)
- if err != nil {
- logError.Println("Failed to create new peer:", err)
- return &IPCError{Code: ipcErrorInvalid}
- }
+ if peer == nil {
+ peer, err = device.NewPeer(publicKey)
+ if err != nil {
+ logError.Println("Failed to create new peer:", err)
+ return &IPCError{Code: ipcErrorInvalid}
}
- peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
- dummy = false
-
+ logDebug.Println("UAPI: Created new peer:", peer.String())
}
+ peer.mutex.Lock()
+ peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
+ peer.mutex.Unlock()
+
case "remove":
// remove currently selected peer from device
@@ -238,7 +270,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid}
}
if !dummy {
- logDebug.Println("Removing", peer.String())
+ logDebug.Println("UAPI: Removing peer:", peer.String())
device.RemovePeer(peer.handshake.remoteStatic)
}
peer = &Peer{}
@@ -248,9 +280,12 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update PSK
- peer.mutex.Lock()
+ logDebug.Println("UAPI: Updating pre-shared key for peer:", peer.String())
+
+ peer.handshake.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value)
- peer.mutex.Unlock()
+ peer.handshake.mutex.Unlock()
+
if err != nil {
logError.Println("Failed to set preshared_key:", err)
return &IPCError{Code: ipcErrorInvalid}
@@ -260,6 +295,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// set endpoint destination
+ logDebug.Println("UAPI: Updating endpoint for peer:", peer.String())
+
err := func() error {
peer.mutex.Lock()
defer peer.mutex.Unlock()
@@ -281,6 +318,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update keep-alive interval
+ logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer.String())
+
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
logError.Println("Failed to set persistent_keepalive_interval:", err)
@@ -299,31 +338,47 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to get tun device status:", err)
return &IPCError{Code: ipcErrorIO}
}
- if device.tun.isUp.Get() && !dummy {
+ if device.isUp.Get() && !dummy {
peer.SendKeepAlive()
}
}
case "replace_allowed_ips":
+
+ logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer.String())
+
if value != "true" {
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
- if !dummy {
- device.routingTable.RemovePeer(peer)
+
+ if dummy {
+ continue
}
+ device.routing.mutex.Lock()
+ device.routing.table.RemovePeer(peer)
+ device.routing.mutex.Unlock()
+
case "allowed_ip":
+
+ logDebug.Println("UAPI: Adding allowed_ip to peer:", peer.String())
+
_, network, err := net.ParseCIDR(value)
if err != nil {
logError.Println("Failed to set allowed_ip:", err)
return &IPCError{Code: ipcErrorInvalid}
}
- ones, _ := network.Mask.Size()
- if !dummy {
- device.routingTable.Insert(network.IP, uint(ones), peer)
+
+ if dummy {
+ continue
}
+ ones, _ := network.Mask.Size()
+ device.routing.mutex.Lock()
+ device.routing.table.Insert(network.IP, uint(ones), peer)
+ device.routing.mutex.Unlock()
+
default:
logError.Println("Invalid UAPI key (peer configuration):", key)
return &IPCError{Code: ipcErrorInvalid}
diff --git a/src/uapi_darwin.go b/uapi_darwin.go
index 63d4d8d..63d4d8d 100644
--- a/src/uapi_darwin.go
+++ b/uapi_darwin.go
diff --git a/src/uapi_linux.go b/uapi_linux.go
index f97a18a..f97a18a 100644
--- a/src/uapi_linux.go
+++ b/uapi_linux.go
diff --git a/src/uapi_windows.go b/uapi_windows.go
index a4599a5..a4599a5 100644
--- a/src/uapi_windows.go
+++ b/uapi_windows.go
diff --git a/src/xchacha20.go b/xchacha20.go
index 5d963e0..5d963e0 100644
--- a/src/xchacha20.go
+++ b/xchacha20.go
diff --git a/src/xchacha20_test.go b/xchacha20_test.go
index 0f41cf8..0f41cf8 100644
--- a/src/xchacha20_test.go
+++ b/xchacha20_test.go