diff options
Diffstat (limited to 'device')
-rw-r--r-- | device/allowedips.go | 41 | ||||
-rw-r--r-- | device/allowedips_rand_test.go | 6 | ||||
-rw-r--r-- | device/allowedips_test.go | 6 | ||||
-rw-r--r-- | device/device_test.go | 6 | ||||
-rw-r--r-- | device/endpoint_test.go | 39 | ||||
-rw-r--r-- | device/receive.go | 1 | ||||
-rw-r--r-- | device/uapi.go | 11 |
7 files changed, 56 insertions, 54 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index c08399b..a6534b1 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -12,6 +12,8 @@ import ( "net" "sync" "unsafe" + + "golang.zx2c4.com/go118/netip" ) type parentIndirection struct { @@ -26,7 +28,7 @@ type trieEntry struct { cidr uint8 bitAtByte uint8 bitAtShift uint8 - bits net.IP + bits []byte perPeerElem *list.Element } @@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 { return bits.ReverseBytes64(i) } -func commonBits(ip1 net.IP, ip2 net.IP) uint8 { +func commonBits(ip1, ip2 []byte) uint8 { size := len(ip1) if size == net.IPv4len { a := (*uint32)(unsafe.Pointer(&ip1[0])) @@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() { } } -func (node *trieEntry) choose(ip net.IP) byte { +func (node *trieEntry) choose(ip []byte) byte { return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } @@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() { node.parent.parentBit = nil } -func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) { +func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { parent = node if parent.cidr == cidr { @@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, return } -func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { +func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { if *trie.parentBit == nil { node := &trieEntry{ peer: peer, @@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) { } } -func (node *trieEntry) lookup(ip net.IP) *Peer { +func (node *trieEntry) lookup(ip []byte) *Peer { var found *Peer size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { @@ -229,13 +231,13 @@ type AllowedIPs struct { mutex sync.RWMutex } -func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) { +func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { table.mutex.RLock() defer table.mutex.RUnlock() for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { node := elem.Value.(*trieEntry) - if !cb(node.bits, node.cidr) { + if !cb(netip.PrefixFrom(netip.AddrFromSlice(node.bits), int(node.cidr))) { return } } @@ -283,28 +285,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { } } -func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { +func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - switch len(ip) { - case net.IPv6len: - parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer) - case net.IPv4len: - parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer) - default: + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else { panic(errors.New("inserting unknown address type")) } } -func (table *AllowedIPs) Lookup(address []byte) *Peer { +func (table *AllowedIPs) Lookup(ip []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() - switch len(address) { + switch len(ip) { case net.IPv6len: - return table.IPv6.lookup(address) + return table.IPv6.lookup(ip) case net.IPv4len: - return table.IPv4.lookup(address) + return table.IPv4.lookup(ip) default: panic(errors.New("looking up unknown address type")) } diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 16de170..ff56fe6 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -10,6 +10,8 @@ import ( "net" "sort" "testing" + + "golang.zx2c4.com/go118/netip" ) const ( @@ -93,14 +95,14 @@ func TestTrieRandom(t *testing.T) { rand.Read(addr4[:]) cidr := uint8(rand.Intn(32) + 1) index := rand.Intn(NumberOfPeers) - allowedIPs.Insert(addr4[:], cidr, peers[index]) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) slow4 = slow4.Insert(addr4[:], cidr, peers[index]) var addr6 [16]byte rand.Read(addr6[:]) cidr = uint8(rand.Intn(128) + 1) index = rand.Intn(NumberOfPeers) - allowedIPs.Insert(addr6[:], cidr, peers[index]) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) slow6 = slow6.Insert(addr6[:], cidr, peers[index]) } diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 2059a88..a274997 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -9,6 +9,8 @@ import ( "math/rand" "net" "testing" + + "golang.zx2c4.com/go118/netip" ) type testPairCommonBits struct { @@ -98,7 +100,7 @@ func TestTrieIPv4(t *testing.T) { var allowedIPs AllowedIPs insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { - allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d byte) { @@ -208,7 +210,7 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - allowedIPs.Insert(addr, cidr, peer) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d uint32) { diff --git a/device/device_test.go b/device/device_test.go index 29daeb9..84221be 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -11,7 +11,6 @@ import ( "fmt" "io" "math/rand" - "net" "runtime" "runtime/pprof" "sync" @@ -19,6 +18,7 @@ import ( "testing" "time" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn/bindtest" "golang.zx2c4.com/wireguard/tun/tuntest" @@ -96,7 +96,7 @@ type testPair [2]testPeer type testPeer struct { tun *tuntest.ChannelTUN dev *Device - ip net.IP + ip netip.Addr } type SendDirection bool @@ -159,7 +159,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { for i := range pair { p := &pair[i] p.tun = tuntest.NewChannelTUN() - p.ip = net.IPv4(1, 0, 0, byte(i+1)) + p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)}) level := LogLevelVerbose if _, ok := tb.(*testing.B); ok && !testing.Verbose() { level = LogLevelError diff --git a/device/endpoint_test.go b/device/endpoint_test.go index 57c361c..f1ae47e 100644 --- a/device/endpoint_test.go +++ b/device/endpoint_test.go @@ -7,47 +7,44 @@ package device import ( "math/rand" - "net" + + "golang.zx2c4.com/go118/netip" ) type DummyEndpoint struct { - src [16]byte - dst [16]byte + src, dst netip.Addr } func CreateDummyEndpoint() (*DummyEndpoint, error) { - var end DummyEndpoint - if _, err := rand.Read(end.src[:]); err != nil { + var src, dst [16]byte + if _, err := rand.Read(src[:]); err != nil { return nil, err } - _, err := rand.Read(end.dst[:]) - return &end, err + _, err := rand.Read(dst[:]) + return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err } func (e *DummyEndpoint) ClearSrc() {} func (e *DummyEndpoint) SrcToString() string { - var addr net.UDPAddr - addr.IP = e.SrcIP() - addr.Port = 1000 - return addr.String() + return netip.AddrPortFrom(e.SrcIP(), 1000).String() } func (e *DummyEndpoint) DstToString() string { - var addr net.UDPAddr - addr.IP = e.DstIP() - addr.Port = 1000 - return addr.String() + return netip.AddrPortFrom(e.DstIP(), 1000).String() } -func (e *DummyEndpoint) SrcToBytes() []byte { - return e.src[:] +func (e *DummyEndpoint) DstToBytes() []byte { + out := e.DstIP().AsSlice() + out = append(out, byte(1000&0xff)) + out = append(out, byte((1000>>8)&0xff)) + return out } -func (e *DummyEndpoint) DstIP() net.IP { - return e.dst[:] +func (e *DummyEndpoint) DstIP() netip.Addr { + return e.dst } -func (e *DummyEndpoint) SrcIP() net.IP { - return e.src[:] +func (e *DummyEndpoint) SrcIP() netip.Addr { + return e.src } diff --git a/device/receive.go b/device/receive.go index 5857481..cc34498 100644 --- a/device/receive.go +++ b/device/receive.go @@ -17,7 +17,6 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/conn" ) diff --git a/device/uapi.go b/device/uapi.go index 66ecd48..9572fb8 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "time" + "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/ipc" ) @@ -121,8 +122,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) - device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool { - sendf("allowed_ip=%s/%d", ip.String(), cidr) + device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) return true }) } @@ -370,16 +371,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error case "allowed_ip": device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) - - _, network, err := net.ParseCIDR(value) + prefix, err := netip.ParsePrefix(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) } if peer.dummy { return nil } - ones, _ := network.Mask.Size() - device.allowedips.Insert(network.IP, uint8(ones), peer.Peer) + device.allowedips.Insert(prefix, peer.Peer) case "protocol_version": if value != "1" { |