aboutsummaryrefslogtreecommitdiff
path: root/allowedips.go
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-03-03 04:04:41 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2019-03-03 05:00:40 +0100
commit69f0fe67b63d90e523a5a1241fb1b46c2e8dbe03 (patch)
tree1ef86da3242afde462dcadb7241bb09f499d5bd7 /allowedips.go
parentd435be35cac49af9367b2005d831d55e570c4b1b (diff)
downloadwireguard-go-69f0fe67b63d90e523a5a1241fb1b46c2e8dbe03.tar.gz
wireguard-go-69f0fe67b63d90e523a5a1241fb1b46c2e8dbe03.zip
global: begin modularization
Diffstat (limited to 'allowedips.go')
-rw-r--r--allowedips.go251
1 files changed, 0 insertions, 251 deletions
diff --git a/allowedips.go b/allowedips.go
deleted file mode 100644
index 2c4f601..0000000
--- a/allowedips.go
+++ /dev/null
@@ -1,251 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package main
-
-import (
- "errors"
- "math/bits"
- "net"
- "sync"
- "unsafe"
-)
-
-type trieEntry struct {
- cidr uint
- child [2]*trieEntry
- bits net.IP
- peer *Peer
-
- // index of "branching" bit
-
- bit_at_byte uint
- bit_at_shift uint
-}
-
-func isLittleEndian() bool {
- one := uint32(1)
- return *(*byte)(unsafe.Pointer(&one)) != 0
-}
-
-func swapU32(i uint32) uint32 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes32(i)
-}
-
-func swapU64(i uint64) uint64 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes64(i)
-}
-
-func commonBits(ip1 net.IP, ip2 net.IP) uint {
- size := len(ip1)
- if size == net.IPv4len {
- a := (*uint32)(unsafe.Pointer(&ip1[0]))
- b := (*uint32)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
- return uint(bits.LeadingZeros32(swapU32(x)))
- } else if size == net.IPv6len {
- a := (*uint64)(unsafe.Pointer(&ip1[0]))
- b := (*uint64)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
- if x != 0 {
- return uint(bits.LeadingZeros64(swapU64(x)))
- }
- a = (*uint64)(unsafe.Pointer(&ip1[8]))
- b = (*uint64)(unsafe.Pointer(&ip2[8]))
- x = *a ^ *b
- return 64 + uint(bits.LeadingZeros64(swapU64(x)))
- } else {
- panic("Wrong size bit string")
- }
-}
-
-func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
- if node == nil {
- return node
- }
-
- // walk recursively
-
- node.child[0] = node.child[0].removeByPeer(p)
- node.child[1] = node.child[1].removeByPeer(p)
-
- if node.peer != p {
- return node
- }
-
- // remove peer & merge
-
- node.peer = nil
- if node.child[0] == nil {
- return node.child[1]
- }
- return node.child[0]
-}
-
-func (node *trieEntry) choose(ip net.IP) byte {
- return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
-}
-
-func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
-
- // at leaf
-
- if node == nil {
- return &trieEntry{
- bits: ip,
- peer: peer,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
- }
- }
-
- // traverse deeper
-
- common := commonBits(node.bits, ip)
- if node.cidr <= cidr && common >= node.cidr {
- if node.cidr == cidr {
- node.peer = peer
- return node
- }
- bit := node.choose(ip)
- node.child[bit] = node.child[bit].insert(ip, cidr, peer)
- return node
- }
-
- // split node
-
- newNode := &trieEntry{
- bits: ip,
- peer: peer,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
- }
-
- cidr = min(cidr, common)
-
- // check for shorter prefix
-
- if newNode.cidr == cidr {
- bit := newNode.choose(node.bits)
- newNode.child[bit] = node
- return newNode
- }
-
- // create new parent for node & newNode
-
- parent := &trieEntry{
- bits: ip,
- peer: nil,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
- }
-
- bit := parent.choose(ip)
- parent.child[bit] = newNode
- parent.child[bit^1] = node
-
- return parent
-}
-
-func (node *trieEntry) lookup(ip net.IP) *Peer {
- var found *Peer
- size := uint(len(ip))
- for node != nil && commonBits(node.bits, ip) >= node.cidr {
- if node.peer != nil {
- found = node.peer
- }
- if node.bit_at_byte == size {
- break
- }
- bit := node.choose(ip)
- node = node.child[bit]
- }
- return found
-}
-
-func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
- if node == nil {
- return results
- }
- if node.peer == p {
- mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
- results = append(results, net.IPNet{
- Mask: mask,
- IP: node.bits.Mask(mask),
- })
- }
- results = node.child[0].entriesForPeer(p, results)
- results = node.child[1].entriesForPeer(p, results)
- return results
-}
-
-type AllowedIPs struct {
- IPv4 *trieEntry
- IPv6 *trieEntry
- mutex sync.RWMutex
-}
-
-func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
- table.mutex.RLock()
- defer table.mutex.RUnlock()
-
- allowed := make([]net.IPNet, 0, 10)
- allowed = table.IPv4.entriesForPeer(peer, allowed)
- allowed = table.IPv6.entriesForPeer(peer, allowed)
- return allowed
-}
-
-func (table *AllowedIPs) Reset() {
- table.mutex.Lock()
- defer table.mutex.Unlock()
-
- table.IPv4 = nil
- table.IPv6 = nil
-}
-
-func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
- table.mutex.Lock()
- defer table.mutex.Unlock()
-
- table.IPv4 = table.IPv4.removeByPeer(peer)
- table.IPv6 = table.IPv6.removeByPeer(peer)
-}
-
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
- table.mutex.Lock()
- defer table.mutex.Unlock()
-
- switch len(ip) {
- case net.IPv6len:
- table.IPv6 = table.IPv6.insert(ip, cidr, peer)
- case net.IPv4len:
- table.IPv4 = table.IPv4.insert(ip, cidr, peer)
- default:
- panic(errors.New("inserting unknown address type"))
- }
-}
-
-func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
- table.mutex.RLock()
- defer table.mutex.RUnlock()
- return table.IPv4.lookup(address)
-}
-
-func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
- table.mutex.RLock()
- defer table.mutex.RUnlock()
- return table.IPv6.lookup(address)
-}