diff options
author | Tyler Kropp <kropptyler@gmail.com> | 2020-03-02 19:41:28 -0500 |
---|---|---|
committer | David Crawshaw <david@zentus.com> | 2020-03-31 09:32:57 +1100 |
commit | c7bb15a70df5cfc949c836429b5e39ce57d047f9 (patch) | |
tree | 09d98c2464f5af1bcf58d41ef7f1fd79f7d3d769 | |
parent | a38504e3994268ee9aee5cb1d42bab82ef5ede76 (diff) | |
download | wireguard-go-c7bb15a70df5cfc949c836429b5e39ce57d047f9.tar.gz wireguard-go-c7bb15a70df5cfc949c836429b5e39ce57d047f9.zip |
wgcfg: add fast CIDR.Contains implementation
Signed-off-by: Tyler Kropp <kropptyler@gmail.com>
-rw-r--r-- | wgcfg/ip.go | 26 | ||||
-rw-r--r-- | wgcfg/ip_test.go | 118 |
2 files changed, 142 insertions, 2 deletions
diff --git a/wgcfg/ip.go b/wgcfg/ip.go index ecf5faf..7541d18 100644 --- a/wgcfg/ip.go +++ b/wgcfg/ip.go @@ -2,6 +2,7 @@ package wgcfg import ( "fmt" + "math" "net" ) @@ -106,12 +107,33 @@ func (r *CIDR) IPNet() *net.IPNet { } return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)} } + func (r *CIDR) Contains(ip *IP) bool { if r == nil || ip == nil { return false } - // TODO: this isn't hard, write a more efficient implementation. - return r.IPNet().Contains(ip.IP()) + c := int8(r.Mask) + i := 0 + if r.IP.Is4() { + i = 12 + if ip.Is6() { + return false + } + } + for ; i < 16 && c > 0; i++ { + var x uint8 + if c < 8 { + x = 8 - uint8(c) + } + m := uint8(math.MaxUint8) >> x << x + a := r.IP.Addr[i] & m + b := ip.Addr[i] & m + if a != b { + return false + } + c -= 8 + } + return true } func (r CIDR) MarshalText() ([]byte, error) { diff --git a/wgcfg/ip_test.go b/wgcfg/ip_test.go new file mode 100644 index 0000000..d3682bb --- /dev/null +++ b/wgcfg/ip_test.go @@ -0,0 +1,118 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgcfg_test + +import ( + "testing" + + "golang.zx2c4.com/wireguard/wgcfg" +) + +func TestCIDRContains(t *testing.T) { + t.Run("home router test", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("192.168.0.0/24") + if err != nil { + t.Fatal(err) + } + ip := wgcfg.ParseIP("192.168.0.1") + if ip == nil { + t.Fatalf("address failed to parse") + } + if !r.Contains(ip) { + t.Fatalf("'%s' should contain '%s'", r, ip) + } + }) + + t.Run("IPv4 outside network", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("192.168.0.0/30") + if err != nil { + t.Fatal(err) + } + ip := wgcfg.ParseIP("192.168.0.4") + if ip == nil { + t.Fatalf("address failed to parse") + } + if r.Contains(ip) { + t.Fatalf("'%s' should not contain '%s'", r, ip) + } + }) + + t.Run("IPv4 does not contain IPv6", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("192.168.0.0/24") + if err != nil { + t.Fatal(err) + } + ip := wgcfg.ParseIP("2001:db8:85a3:0:0:8a2e:370:7334") + if ip == nil { + t.Fatalf("address failed to parse") + } + if r.Contains(ip) { + t.Fatalf("'%s' should not contain '%s'", r, ip) + } + }) + + t.Run("IPv6 inside network", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("2001:db8:1234::/48") + if err != nil { + t.Fatal(err) + } + ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001") + if ip == nil { + t.Fatalf("ParseIP returned nil pointer") + } + if !r.Contains(ip) { + t.Fatalf("'%s' should not contain '%s'", r, ip) + } + }) + + t.Run("IPv6 outside network", func(t *testing.T) { + r, err := wgcfg.ParseCIDR("2001:db8:1234:0:190b:0:1982::/126") + if err != nil { + t.Fatal(err) + } + ip := wgcfg.ParseIP("2001:db8:1234:0:190b:0:1982:4") + if ip == nil { + t.Fatalf("ParseIP returned nil pointer") + } + if r.Contains(ip) { + t.Fatalf("'%s' should not contain '%s'", r, ip) + } + }) +} + +func BenchmarkCIDRContainsIPv4(b *testing.B) { + b.Run("IPv4", func(b *testing.B) { + r, err := wgcfg.ParseCIDR("192.168.1.0/24") + if err != nil { + b.Fatal(err) + } + ip := wgcfg.ParseIP("1.2.3.4") + if ip == nil { + b.Fatalf("ParseIP returned nil pointer") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.Contains(ip) + } + }) + + b.Run("IPv6", func(b *testing.B) { + r, err := wgcfg.ParseCIDR("2001:db8:1234::/48") + if err != nil { + b.Fatal(err) + } + ip := wgcfg.ParseIP("2001:db8:1234:0000:0000:0000:0000:0001") + if ip == nil { + b.Fatalf("ParseIP returned nil pointer") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r.Contains(ip) + } + }) +} |