diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-11-05 01:52:54 +0100 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-11-23 22:03:15 +0100 |
commit | ef8d6804d77d9ce09f0e2c7f6d85bbe222712b73 (patch) | |
tree | 5b4a3b53dfb092f10cf11fbe0b5724f58df3a1bf /device/allowedips.go | |
parent | de7c702ace45b8eeba7f4de8ecd9c85c80806264 (diff) | |
download | wireguard-go-ef8d6804d77d9ce09f0e2c7f6d85bbe222712b73.tar.gz wireguard-go-ef8d6804d77d9ce09f0e2c7f6d85bbe222712b73.zip |
global: use netip where possible now
There are more places where we'll need to add it later, when Go 1.18
comes out with support for it in the "net" package. Also, allowedips
still uses slices internally, which might be suboptimal.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'device/allowedips.go')
-rw-r--r-- | device/allowedips.go | 42 |
1 files changed, 23 insertions, 19 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index c08399b..7a0b275 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -12,6 +12,8 @@ import ( "net" "sync" "unsafe" + + "golang.zx2c4.com/go118/netip" ) type parentIndirection struct { @@ -26,7 +28,7 @@ type trieEntry struct { cidr uint8 bitAtByte uint8 bitAtShift uint8 - bits net.IP + bits []byte perPeerElem *list.Element } @@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 { return bits.ReverseBytes64(i) } -func commonBits(ip1 net.IP, ip2 net.IP) uint8 { +func commonBits(ip1, ip2 []byte) uint8 { size := len(ip1) if size == net.IPv4len { a := (*uint32)(unsafe.Pointer(&ip1[0])) @@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() { } } -func (node *trieEntry) choose(ip net.IP) byte { +func (node *trieEntry) choose(ip []byte) byte { return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } @@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() { node.parent.parentBit = nil } -func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) { +func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { parent = node if parent.cidr == cidr { @@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, return } -func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { +func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { if *trie.parentBit == nil { node := &trieEntry{ peer: peer, @@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { } } -func (node *trieEntry) lookup(ip net.IP) *Peer { +func (node *trieEntry) lookup(ip []byte) *Peer { var found *Peer size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { @@ -229,13 +231,14 @@ type AllowedIPs struct { mutex sync.RWMutex } -func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) { +func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { table.mutex.RLock() defer table.mutex.RUnlock() for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { node := elem.Value.(*trieEntry) - if !cb(node.bits, node.cidr) { + a, _ := netip.AddrFromSlice(node.bits) + if !cb(netip.PrefixFrom(a, int(node.cidr))) { return } } @@ -283,28 +286,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { } } -func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { +func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - switch len(ip) { - case net.IPv6len: - parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer) - case net.IPv4len: - parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer) - default: + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else { panic(errors.New("inserting unknown address type")) } } -func (table *AllowedIPs) Lookup(address []byte) *Peer { +func (table *AllowedIPs) Lookup(ip []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() - switch len(address) { + switch len(ip) { case net.IPv6len: - return table.IPv6.lookup(address) + return table.IPv6.lookup(ip) case net.IPv4len: - return table.IPv4.lookup(address) + return table.IPv4.lookup(ip) default: panic(errors.New("looking up unknown address type")) } |