diff options
Diffstat (limited to 'device/device.go')
-rw-r--r-- | device/device.go | 279 |
1 files changed, 151 insertions, 128 deletions
diff --git a/device/device.go b/device/device.go index 9ea7c24..c637e38 100644 --- a/device/device.go +++ b/device/device.go @@ -21,17 +21,26 @@ import ( ) type Device struct { - isUp AtomicBool // device is (going) up - isClosed AtomicBool // device is closed? (acting as guard) - log *Logger + log *Logger // synchronized resources (locks acquired in order) state struct { + // state holds the device's state. It is accessed atomically. + // Use the device.deviceState method to read it. + // If state.mu is (r)locked, state is the current state of the device. + // Without state.mu (r)locked, state is either the current state + // of the device or the intended future state of the device. + // For example, while executing a call to Up, state will be deviceStateUp. + // There is no guarantee that that intended future state of the device + // will become the actual state; Up can fail. + // The device can also change state multiple times between time of check and time of use. + // Unsynchronized uses of state must therefore be advisory/best-effort only. + state uint32 // actually a deviceState, but typed uint32 for conveniene + // stopping blocks until all inputs to Device have been closed. stopping sync.WaitGroup - sync.Mutex - changing AtomicBool - current bool + // mu protects state changes. + mu sync.Mutex } net struct { @@ -87,6 +96,43 @@ type Device struct { closed chan struct{} } +// deviceState represents the state of a Device. +// There are four states: new, down, up, closed. +// However, state new should never be observable. +// Transitions: +// +// new -> down -----+ +// ↑↓ ↓ +// up -> closed +// +type deviceState uint32 + +//go:generate stringer -type deviceState -trimprefix=deviceState +const ( + deviceStateNew deviceState = iota + deviceStateDown + deviceStateUp + deviceStateClosed +) + +// deviceState returns device.state.state as a deviceState +// See those docs for how to interpret this value. +func (device *Device) deviceState() deviceState { + return deviceState(atomic.LoadUint32(&device.state.state)) +} + +// isClosed reports whether the device is closed (or is closing). +// See device.state.state comments for how to interpret this value. +func (device *Device) isClosed() bool { + return device.deviceState() == deviceStateClosed +} + +// isUp reports whether the device is up (or is attempting to come up). +// See device.state.state comments for how to interpret this value. +func (device *Device) isUp() bool { + return device.deviceState() == deviceStateUp +} + // An outboundQueue is a channel of QueueOutboundElements awaiting encryption. // An outboundQueue is ref-counted using its wg field. // An outboundQueue created with newOutboundQueue has one reference. @@ -154,91 +200,82 @@ func newHandshakeQueue() *handshakeQueue { * Must hold device.peers.Mutex */ func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { - // stop routing and processing of packets - device.allowedips.RemoveByPeer(peer) peer.Stop() // remove from peer map - delete(device.peers.keyMap, key) device.peers.empty.Set(len(device.peers.keyMap) == 0) } -func deviceUpdateState(device *Device) { - - // check if state already being updated (guard) - - if device.state.changing.Swap(true) { +// changeState attempts to change the device state to match want. +func (device *Device) changeState(want deviceState) { + device.state.mu.Lock() + defer device.state.mu.Unlock() + old := device.deviceState() + if old == deviceStateClosed { + // once closed, always closed + device.log.Verbosef("Interface closed, ignored requested state %s", want) return } - - // compare to current state of device - - device.state.Lock() - - newIsUp := device.isUp.Get() - - if newIsUp == device.state.current { - device.state.changing.Set(false) - device.state.Unlock() + switch want { + case old: + device.log.Verbosef("Interface already in state %s", want) return - } - - // change state of device - - switch newIsUp { - case true: - if err := device.BindUpdate(); err != nil { - device.log.Errorf("Unable to update bind: %v", err) - device.isUp.Set(false) + case deviceStateUp: + atomic.StoreUint32(&device.state.state, uint32(deviceStateUp)) + if ok := device.upLocked(); ok { break } - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Start() - if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { - peer.SendKeepalive() - } - } - device.peers.RUnlock() - - case false: - device.BindClose() - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Stop() - } - device.peers.RUnlock() + fallthrough // up failed; bring the device all the way back down + case deviceStateDown: + atomic.StoreUint32(&device.state.state, uint32(deviceStateDown)) + device.downLocked() } + device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) +} - // update state variables - - device.state.current = newIsUp - device.state.changing.Set(false) - device.state.Unlock() - - // check for state change in the mean time +// upLocked attempts to bring the device up and reports whether it succeeded. +// The caller must hold device.state.mu and is responsible for updating device.state.state. +func (device *Device) upLocked() bool { + if err := device.BindUpdate(); err != nil { + device.log.Errorf("Unable to update bind: %v", err) + return false + } - deviceUpdateState(device) + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Start() + if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 { + peer.SendKeepalive() + } + } + device.peers.RUnlock() + return true } -func (device *Device) Up() { - - // closed device cannot be brought up +// downLocked attempts to bring the device down. +// The caller must hold device.state.mu and is responsible for updating device.state.state. +func (device *Device) downLocked() { + err := device.BindClose() + if err != nil { + device.log.Errorf("Bind close failed: %v", err) + } - if device.isClosed.Get() { - return + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Stop() } + device.peers.RUnlock() +} - device.isUp.Set(true) - deviceUpdateState(device) +func (device *Device) Up() { + device.changeState(deviceStateUp) } func (device *Device) Down() { - device.isUp.Set(false) - deviceUpdateState(device) + device.changeState(deviceStateDown) } func (device *Device) IsUnderLoad() bool { @@ -310,6 +347,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { func NewDevice(tunDevice tun.Device, logger *Logger) *Device { device := new(Device) + device.state.state = uint32(deviceStateDown) device.closed = make(chan struct{}) device.log = logger device.tun.device = tunDevice @@ -382,19 +420,16 @@ func (device *Device) RemoveAllPeers() { } func (device *Device) Close() { - if device.isClosed.Swap(true) { + device.state.mu.Lock() + defer device.state.mu.Unlock() + if device.isClosed() { return } - + atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed)) device.log.Verbosef("Device closing") - device.state.changing.Set(true) - device.state.Lock() - defer device.state.Unlock() device.tun.device.Close() - device.BindClose() - - device.isUp.Set(false) + device.downLocked() // Remove peers before closing queues, // because peers assume that queues are active. @@ -410,8 +445,7 @@ func (device *Device) Close() { device.rate.limiter.Close() - device.state.changing.Set(false) - device.log.Verbosef("Interface closed") + device.log.Verbosef("Device closed") close(device.closed) } @@ -420,7 +454,7 @@ func (device *Device) Wait() chan struct{} { } func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { - if device.isClosed.Get() { + if !device.isUp() { return } @@ -457,27 +491,23 @@ func (device *Device) Bind() conn.Bind { } func (device *Device) BindSetMark(mark uint32) error { - device.net.Lock() defer device.net.Unlock() // check if modified - if device.net.fwmark == mark { return nil } // update fwmark on existing bind - device.net.fwmark = mark - if device.isUp.Get() && device.net.bind != nil { + if device.isUp() && device.net.bind != nil { if err := device.net.bind.SetMark(mark); err != nil { return err } } // clear cached source addresses - device.peers.RLock() for _, peer := range device.peers.keyMap { peer.Lock() @@ -492,70 +522,63 @@ func (device *Device) BindSetMark(mark uint32) error { } func (device *Device) BindUpdate() error { - device.net.Lock() defer device.net.Unlock() // close existing sockets - if err := unsafeCloseBind(device); err != nil { return err } // open new sockets + if !device.isUp() { + return nil + } - if device.isUp.Get() { - - // bind to new port + // bind to new port + var err error + netc := &device.net + netc.bind, netc.port, err = conn.CreateBind(netc.port) + if err != nil { + netc.bind = nil + netc.port = 0 + return err + } + netc.netlinkCancel, err = device.startRouteListener(netc.bind) + if err != nil { + netc.bind.Close() + netc.bind = nil + netc.port = 0 + return err + } - var err error - netc := &device.net - netc.bind, netc.port, err = conn.CreateBind(netc.port) - if err != nil { - netc.bind = nil - netc.port = 0 - return err - } - netc.netlinkCancel, err = device.startRouteListener(netc.bind) + // set fwmark + if netc.fwmark != 0 { + err = netc.bind.SetMark(netc.fwmark) if err != nil { - netc.bind.Close() - netc.bind = nil - netc.port = 0 return err } + } - // set fwmark - - if netc.fwmark != 0 { - err = netc.bind.SetMark(netc.fwmark) - if err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } + // clear cached source addresses + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() } - device.peers.RUnlock() - - // start receiving routines - - device.net.stopping.Add(2) - device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption - device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) - - device.log.Verbosef("UDP bind has been updated") } + device.peers.RUnlock() + + // start receiving routines + device.net.stopping.Add(2) + device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption + device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake + go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) + go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) + device.log.Verbosef("UDP bind has been updated") return nil } |