diff options
Diffstat (limited to 'device/device.go')
-rw-r--r-- | device/device.go | 27 |
1 files changed, 14 insertions, 13 deletions
diff --git a/device/device.go b/device/device.go index a9fedea..081d59f 100644 --- a/device/device.go +++ b/device/device.go @@ -17,6 +17,7 @@ import ( "golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/wgcfg" ) type Device struct { @@ -46,13 +47,13 @@ type Device struct { staticIdentity struct { sync.RWMutex - privateKey NoisePrivateKey - publicKey NoisePublicKey + privateKey wgcfg.PrivateKey + publicKey wgcfg.Key } peers struct { sync.RWMutex - keyMap map[NoisePublicKey]*Peer + keyMap map[wgcfg.Key]*Peer } // unprotected / "self-synchronising resources" @@ -96,7 +97,7 @@ type Device struct { * * Must hold device.peers.Mutex */ -func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { +func unsafeRemovePeer(device *Device, peer *Peer, key wgcfg.Key) { // stop routing and processing of packets @@ -200,13 +201,13 @@ func (device *Device) IsUnderLoad() bool { return until.After(now) } -func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { +func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { // lock required resources device.staticIdentity.Lock() defer device.staticIdentity.Unlock() - if sk.Equals(device.staticIdentity.privateKey) { + if sk.Equal(device.staticIdentity.privateKey) { return nil } @@ -221,9 +222,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { // remove peers with matching public keys - publicKey := sk.publicKey() + publicKey := sk.Public() for key, peer := range device.peers.keyMap { - if peer.handshake.remoteStatic.Equals(publicKey) { + if peer.handshake.remoteStatic.Equal(publicKey) { unsafeRemovePeer(device, peer, key) } } @@ -239,7 +240,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { handshake := &peer.handshake - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(handshake.remoteStatic) expiredPeers = append(expiredPeers, peer) } @@ -269,7 +270,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { } device.tun.mtu = int32(mtu) - device.peers.keyMap = make(map[NoisePublicKey]*Peer) + device.peers.keyMap = make(map[wgcfg.Key]*Peer) device.rate.limiter.Init() device.rate.underLoadUntil.Store(time.Time{}) @@ -317,14 +318,14 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { return device } -func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { +func (device *Device) LookupPeer(pk wgcfg.Key) *Peer { device.peers.RLock() defer device.peers.RUnlock() return device.peers.keyMap[pk] } -func (device *Device) RemovePeer(key NoisePublicKey) { +func (device *Device) RemovePeer(key wgcfg.Key) { device.peers.Lock() defer device.peers.Unlock() // stop peer and remove from routing @@ -343,7 +344,7 @@ func (device *Device) RemoveAllPeers() { unsafeRemovePeer(device, peer, key) } - device.peers.keyMap = make(map[NoisePublicKey]*Peer) + device.peers.keyMap = make(map[wgcfg.Key]*Peer) } func (device *Device) FlushPacketQueues() { |