diff options
Diffstat (limited to 'device/device.go')
-rw-r--r-- | device/device.go | 27 |
1 files changed, 13 insertions, 14 deletions
diff --git a/device/device.go b/device/device.go index 081d59f..a9fedea 100644 --- a/device/device.go +++ b/device/device.go @@ -17,7 +17,6 @@ import ( "golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/wgcfg" ) type Device struct { @@ -47,13 +46,13 @@ type Device struct { staticIdentity struct { sync.RWMutex - privateKey wgcfg.PrivateKey - publicKey wgcfg.Key + privateKey NoisePrivateKey + publicKey NoisePublicKey } peers struct { sync.RWMutex - keyMap map[wgcfg.Key]*Peer + keyMap map[NoisePublicKey]*Peer } // unprotected / "self-synchronising resources" @@ -97,7 +96,7 @@ type Device struct { * * Must hold device.peers.Mutex */ -func unsafeRemovePeer(device *Device, peer *Peer, key wgcfg.Key) { +func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { // stop routing and processing of packets @@ -201,13 +200,13 @@ func (device *Device) IsUnderLoad() bool { return until.After(now) } -func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { +func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { // lock required resources device.staticIdentity.Lock() defer device.staticIdentity.Unlock() - if sk.Equal(device.staticIdentity.privateKey) { + if sk.Equals(device.staticIdentity.privateKey) { return nil } @@ -222,9 +221,9 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { // remove peers with matching public keys - publicKey := sk.Public() + publicKey := sk.publicKey() for key, peer := range device.peers.keyMap { - if peer.handshake.remoteStatic.Equal(publicKey) { + if peer.handshake.remoteStatic.Equals(publicKey) { unsafeRemovePeer(device, peer, key) } } @@ -240,7 +239,7 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { handshake := &peer.handshake - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) expiredPeers = append(expiredPeers, peer) } @@ -270,7 +269,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { } device.tun.mtu = int32(mtu) - device.peers.keyMap = make(map[wgcfg.Key]*Peer) + device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.rate.limiter.Init() device.rate.underLoadUntil.Store(time.Time{}) @@ -318,14 +317,14 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { return device } -func (device *Device) LookupPeer(pk wgcfg.Key) *Peer { +func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { device.peers.RLock() defer device.peers.RUnlock() return device.peers.keyMap[pk] } -func (device *Device) RemovePeer(key wgcfg.Key) { +func (device *Device) RemovePeer(key NoisePublicKey) { device.peers.Lock() defer device.peers.Unlock() // stop peer and remove from routing @@ -344,7 +343,7 @@ func (device *Device) RemoveAllPeers() { unsafeRemovePeer(device, peer, key) } - device.peers.keyMap = make(map[wgcfg.Key]*Peer) + device.peers.keyMap = make(map[NoisePublicKey]*Peer) } func (device *Device) FlushPacketQueues() { |