summaryrefslogtreecommitdiff
path: root/drivers
diff options
context:
space:
mode:
Diffstat (limited to 'drivers')
-rw-r--r--drivers/net/wireguard/allowedips.c87
-rw-r--r--drivers/net/wireguard/allowedips.h14
-rw-r--r--drivers/net/wireguard/device.c2
-rw-r--r--drivers/net/wireguard/netlink.c15
-rw-r--r--drivers/net/wireguard/peer.c5
-rw-r--r--drivers/net/wireguard/selftest/allowedips.c106
6 files changed, 128 insertions, 101 deletions
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c
index 3725e9c..c52782b 100644
--- a/drivers/net/wireguard/allowedips.c
+++ b/drivers/net/wireguard/allowedips.c
@@ -66,8 +66,7 @@ static void root_remove_peer_lists(struct allowedips_node *root)
}
}
-static void walk_remove_by_peer(struct allowedips_node __rcu **top,
- struct wg_peer *peer, struct mutex *lock)
+static void walk_cleanup_empty_nodes(struct allowedips_node __rcu **top, struct mutex *lock)
{
#define REF(p) rcu_access_pointer(p)
#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
@@ -80,7 +79,7 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top,
struct allowedips_node *node, *prev;
unsigned int len;
- if (unlikely(!peer || !REF(*top)))
+ if (unlikely(!REF(*top)))
return;
for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
@@ -100,16 +99,11 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top,
if (REF(node->bit[1]))
PUSH(&node->bit[1]);
} else {
- if (rcu_dereference_protected(node->peer,
- lockdep_is_held(lock)) == peer) {
- RCU_INIT_POINTER(node->peer, NULL);
- list_del_init(&node->peer_list);
- if (!node->bit[0] || !node->bit[1]) {
- rcu_assign_pointer(*nptr, DEREF(
- &node->bit[!REF(node->bit[0])]));
- kfree_rcu(node, rcu);
- node = DEREF(nptr);
- }
+ if (!rcu_dereference_protected(node->peer, lockdep_is_held(lock)) &&
+ (!node->bit[0] || !node->bit[1])) {
+ rcu_assign_pointer(*nptr, DEREF(&node->bit[!REF(node->bit[0])]));
+ kfree_rcu(node, rcu);
+ node = DEREF(nptr);
}
--len;
}
@@ -281,29 +275,54 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
return 0;
}
-void wg_allowedips_init(struct allowedips *table)
+static void wg_allowedips_cleanup(struct work_struct *work)
+{
+ struct delayed_work *dwork = to_delayed_work(work);
+ struct allowedips *table = container_of(dwork, struct allowedips, cleanup_work);
+
+ if (!mutex_trylock(table->update_lock)) {
+ queue_delayed_work(system_power_efficient_wq, dwork, HZ / 4);
+ return;
+ }
+ if (table->dirty4) {
+ walk_cleanup_empty_nodes(&table->root4, table->update_lock);
+ table->dirty4 = false;
+ }
+ if (table->dirty6) {
+ walk_cleanup_empty_nodes(&table->root6, table->update_lock);
+ table->dirty6 = false;
+ }
+ mutex_unlock(table->update_lock);
+}
+
+void wg_allowedips_init(struct allowedips *table, struct mutex *update_lock)
{
table->root4 = table->root6 = NULL;
+ table->update_lock = update_lock;
table->seq = 1;
+ table->dirty4 = table->dirty6 = false;
+ INIT_DEFERRABLE_WORK(&table->cleanup_work, wg_allowedips_cleanup);
}
-void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
+void wg_allowedips_free(struct allowedips *table)
{
struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6;
++table->seq;
+ table->dirty4 = table->dirty6 = false;
+ cancel_delayed_work_sync(&table->cleanup_work);
RCU_INIT_POINTER(table->root4, NULL);
RCU_INIT_POINTER(table->root6, NULL);
if (rcu_access_pointer(old4)) {
struct allowedips_node *node = rcu_dereference_protected(old4,
- lockdep_is_held(lock));
+ lockdep_is_held(table->update_lock));
root_remove_peer_lists(node);
call_rcu(&node->rcu, root_free_rcu);
}
if (rcu_access_pointer(old6)) {
struct allowedips_node *node = rcu_dereference_protected(old6,
- lockdep_is_held(lock));
+ lockdep_is_held(table->update_lock));
root_remove_peer_lists(node);
call_rcu(&node->rcu, root_free_rcu);
@@ -311,33 +330,51 @@ void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
}
int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
- u8 cidr, struct wg_peer *peer, struct mutex *lock)
+ u8 cidr, struct wg_peer *peer)
{
/* Aligned so it can be passed to fls */
u8 key[4] __aligned(__alignof(u32));
++table->seq;
+ if (table->dirty4) {
+ walk_cleanup_empty_nodes(&table->root4, table->update_lock);
+ table->dirty4 = false;
+ }
swap_endian(key, (const u8 *)ip, 32);
- return add(&table->root4, 32, key, cidr, peer, lock);
+ return add(&table->root4, 32, key, cidr, peer, table->update_lock);
}
int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
- u8 cidr, struct wg_peer *peer, struct mutex *lock)
+ u8 cidr, struct wg_peer *peer)
{
/* Aligned so it can be passed to fls64 */
u8 key[16] __aligned(__alignof(u64));
++table->seq;
+ if (table->dirty6) {
+ walk_cleanup_empty_nodes(&table->root6, table->update_lock);
+ table->dirty6 = false;
+ }
swap_endian(key, (const u8 *)ip, 128);
- return add(&table->root6, 128, key, cidr, peer, lock);
+ return add(&table->root6, 128, key, cidr, peer, table->update_lock);
}
-void wg_allowedips_remove_by_peer(struct allowedips *table,
- struct wg_peer *peer, struct mutex *lock)
+void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer)
{
+ struct allowedips_node *node, *tmp;
+
+ lockdep_assert_held(table->update_lock);
++table->seq;
- walk_remove_by_peer(&table->root4, peer, lock);
- walk_remove_by_peer(&table->root6, peer, lock);
+
+ list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
+ list_del_init(&node->peer_list);
+ RCU_INIT_POINTER(node->peer, NULL);
+ table->dirty4 |= node->bitlen == 32;
+ table->dirty6 |= node->bitlen == 128;
+ }
+
+ if (table->dirty4 || table->dirty6)
+ queue_delayed_work(system_power_efficient_wq, &table->cleanup_work, HZ);
}
int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
diff --git a/drivers/net/wireguard/allowedips.h b/drivers/net/wireguard/allowedips.h
index e5c83ca..602c334 100644
--- a/drivers/net/wireguard/allowedips.h
+++ b/drivers/net/wireguard/allowedips.h
@@ -32,17 +32,19 @@ struct allowedips_node {
struct allowedips {
struct allowedips_node __rcu *root4;
struct allowedips_node __rcu *root6;
+ struct mutex *update_lock;
+ struct delayed_work cleanup_work;
u64 seq;
+ bool dirty4, dirty6;
};
-void wg_allowedips_init(struct allowedips *table);
-void wg_allowedips_free(struct allowedips *table, struct mutex *mutex);
+void wg_allowedips_init(struct allowedips *table, struct mutex *update_lock);
+void wg_allowedips_free(struct allowedips *table);
int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
- u8 cidr, struct wg_peer *peer, struct mutex *lock);
+ u8 cidr, struct wg_peer *peer);
int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
- u8 cidr, struct wg_peer *peer, struct mutex *lock);
-void wg_allowedips_remove_by_peer(struct allowedips *table,
- struct wg_peer *peer, struct mutex *lock);
+ u8 cidr, struct wg_peer *peer);
+void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer);
/* The ip input pointer should be __aligned(__alignof(u64))) */
int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr);
diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c
index 551ddaa..610ce1d 100644
--- a/drivers/net/wireguard/device.c
+++ b/drivers/net/wireguard/device.c
@@ -299,7 +299,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
mutex_init(&wg->socket_update_lock);
mutex_init(&wg->device_update_lock);
skb_queue_head_init(&wg->incoming_handshakes);
- wg_allowedips_init(&wg->peer_allowedips);
+ wg_allowedips_init(&wg->peer_allowedips, &wg->device_update_lock);
wg_cookie_checker_init(&wg->cookie_checker, wg);
INIT_LIST_HEAD(&wg->peer_list);
wg->device_update_gen = 1;
diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c
index d0f3b6d..9dcbc57 100644
--- a/drivers/net/wireguard/netlink.c
+++ b/drivers/net/wireguard/netlink.c
@@ -340,16 +340,12 @@ static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs)
if (family == AF_INET && cidr <= 32 &&
nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr))
- ret = wg_allowedips_insert_v4(
- &peer->device->peer_allowedips,
- nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
- &peer->device->device_update_lock);
+ ret = wg_allowedips_insert_v4(&peer->device->peer_allowedips,
+ nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer);
else if (family == AF_INET6 && cidr <= 128 &&
nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr))
- ret = wg_allowedips_insert_v6(
- &peer->device->peer_allowedips,
- nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
- &peer->device->device_update_lock);
+ ret = wg_allowedips_insert_v6(&peer->device->peer_allowedips,
+ nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer);
return ret;
}
@@ -449,8 +445,7 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs)
}
if (flags & WGPEER_F_REPLACE_ALLOWEDIPS)
- wg_allowedips_remove_by_peer(&wg->peer_allowedips, peer,
- &wg->device_update_lock);
+ wg_allowedips_remove_by_peer(&wg->peer_allowedips, peer);
if (attrs[WGPEER_A_ALLOWEDIPS]) {
struct nlattr *attr, *allowedip[WGALLOWEDIP_A_MAX + 1];
diff --git a/drivers/net/wireguard/peer.c b/drivers/net/wireguard/peer.c
index 3a042d2..3a14a37 100644
--- a/drivers/net/wireguard/peer.c
+++ b/drivers/net/wireguard/peer.c
@@ -81,8 +81,7 @@ static void peer_make_dead(struct wg_peer *peer)
{
/* Remove from configuration-time lookup structures. */
list_del_init(&peer->peer_list);
- wg_allowedips_remove_by_peer(&peer->device->peer_allowedips, peer,
- &peer->device->device_update_lock);
+ wg_allowedips_remove_by_peer(&peer->device->peer_allowedips, peer);
wg_pubkey_hashtable_remove(peer->device->peer_hashtable, peer);
/* Mark as dead, so that we don't allow jumping contexts after. */
@@ -172,7 +171,7 @@ void wg_peer_remove_all(struct wg_device *wg)
lockdep_assert_held(&wg->device_update_lock);
/* Avoid having to traverse individually for each one. */
- wg_allowedips_free(&wg->peer_allowedips, &wg->device_update_lock);
+ wg_allowedips_free(&wg->peer_allowedips);
list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list) {
peer_make_dead(peer);
diff --git a/drivers/net/wireguard/selftest/allowedips.c b/drivers/net/wireguard/selftest/allowedips.c
index 0d2a43a..61202d5 100644
--- a/drivers/net/wireguard/selftest/allowedips.c
+++ b/drivers/net/wireguard/selftest/allowedips.c
@@ -273,22 +273,21 @@ static __init bool randomized_test(void)
{
unsigned int i, j, k, mutate_amount, cidr;
u8 ip[16], mutate_mask[16], mutated[16];
- struct wg_peer **peers, *peer;
+ struct wg_peer **peers = kcalloc(NUM_PEERS, sizeof(*peers), GFP_KERNEL), *peer;
struct horrible_allowedips h;
DEFINE_MUTEX(mutex);
- struct allowedips t;
+ struct allowedips *t = kmalloc(sizeof(*t), GFP_KERNEL);
bool ret = false;
- mutex_init(&mutex);
-
- wg_allowedips_init(&t);
- horrible_allowedips_init(&h);
-
- peers = kcalloc(NUM_PEERS, sizeof(*peers), GFP_KERNEL);
- if (unlikely(!peers)) {
+ if (unlikely(!t || !peers)) {
pr_err("allowedips random self-test malloc: FAIL\n");
goto free;
}
+
+ mutex_init(&mutex);
+ wg_allowedips_init(t, &mutex);
+ horrible_allowedips_init(&h);
+
for (i = 0; i < NUM_PEERS; ++i) {
peers[i] = kzalloc(sizeof(*peers[i]), GFP_KERNEL);
if (unlikely(!peers[i])) {
@@ -305,13 +304,11 @@ static __init bool randomized_test(void)
prandom_bytes(ip, 4);
cidr = prandom_u32_max(32) + 1;
peer = peers[prandom_u32_max(NUM_PEERS)];
- if (wg_allowedips_insert_v4(&t, (struct in_addr *)ip, cidr,
- peer, &mutex) < 0) {
+ if (wg_allowedips_insert_v4(t, (struct in_addr *)ip, cidr, peer) < 0) {
pr_err("allowedips random self-test malloc: FAIL\n");
goto free_locked;
}
- if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip,
- cidr, peer) < 0) {
+ if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip, cidr, peer) < 0) {
pr_err("allowedips random self-test malloc: FAIL\n");
goto free_locked;
}
@@ -331,14 +328,13 @@ static __init bool randomized_test(void)
prandom_u32_max(256));
cidr = prandom_u32_max(32) + 1;
peer = peers[prandom_u32_max(NUM_PEERS)];
- if (wg_allowedips_insert_v4(&t,
- (struct in_addr *)mutated,
- cidr, peer, &mutex) < 0) {
+ if (wg_allowedips_insert_v4(t, (struct in_addr *)mutated,
+ cidr, peer) < 0) {
pr_err("allowedips random self-test malloc: FAIL\n");
goto free_locked;
}
- if (horrible_allowedips_insert_v4(&h,
- (struct in_addr *)mutated, cidr, peer)) {
+ if (horrible_allowedips_insert_v4(&h, (struct in_addr *)mutated,
+ cidr, peer)) {
pr_err("allowedips random self-test malloc: FAIL\n");
goto free_locked;
}
@@ -349,13 +345,11 @@ static __init bool randomized_test(void)
prandom_bytes(ip, 16);
cidr = prandom_u32_max(128) + 1;
peer = peers[prandom_u32_max(NUM_PEERS)];
- if (wg_allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr,
- peer, &mutex) < 0) {
+ if (wg_allowedips_insert_v6(t, (struct in6_addr *)ip, cidr, peer) < 0) {
pr_err("allowedips random self-test malloc: FAIL\n");
goto free_locked;
}
- if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip,
- cidr, peer) < 0) {
+ if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip, cidr, peer) < 0) {
pr_err("allowedips random self-test malloc: FAIL\n");
goto free_locked;
}
@@ -375,15 +369,13 @@ static __init bool randomized_test(void)
prandom_u32_max(256));
cidr = prandom_u32_max(128) + 1;
peer = peers[prandom_u32_max(NUM_PEERS)];
- if (wg_allowedips_insert_v6(&t,
- (struct in6_addr *)mutated,
- cidr, peer, &mutex) < 0) {
+ if (wg_allowedips_insert_v6(t, (struct in6_addr *)mutated,
+ cidr, peer) < 0) {
pr_err("allowedips random self-test malloc: FAIL\n");
goto free_locked;
}
- if (horrible_allowedips_insert_v6(
- &h, (struct in6_addr *)mutated, cidr,
- peer)) {
+ if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)mutated,
+ cidr, peer)) {
pr_err("allowedips random self-test malloc: FAIL\n");
goto free_locked;
}
@@ -393,13 +385,13 @@ static __init bool randomized_test(void)
mutex_unlock(&mutex);
if (IS_ENABLED(DEBUG_PRINT_TRIE_GRAPHVIZ)) {
- print_tree(t.root4, 32);
- print_tree(t.root6, 128);
+ print_tree(t->root4, 32);
+ print_tree(t->root6, 128);
}
for (i = 0; i < NUM_QUERIES; ++i) {
prandom_bytes(ip, 4);
- if (lookup(t.root4, 32, ip) !=
+ if (lookup(t->root4, 32, ip) !=
horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
pr_err("allowedips random self-test: FAIL\n");
goto free;
@@ -408,7 +400,7 @@ static __init bool randomized_test(void)
for (i = 0; i < NUM_QUERIES; ++i) {
prandom_bytes(ip, 16);
- if (lookup(t.root6, 128, ip) !=
+ if (lookup(t->root6, 128, ip) !=
horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
pr_err("allowedips random self-test: FAIL\n");
goto free;
@@ -419,7 +411,8 @@ static __init bool randomized_test(void)
free:
mutex_lock(&mutex);
free_locked:
- wg_allowedips_free(&t, &mutex);
+ wg_allowedips_free(t);
+ kfree(t);
mutex_unlock(&mutex);
horrible_allowedips_free(&h);
if (peers) {
@@ -466,8 +459,8 @@ static __init struct wg_peer *init_peer(void)
}
#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \
- wg_allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \
- cidr, mem, &mutex)
+ wg_allowedips_insert_v##version(t, ip##version(ipa, ipb, ipc, ipd), \
+ cidr, mem)
#define maybe_fail() do { \
++i; \
@@ -477,16 +470,16 @@ static __init struct wg_peer *init_peer(void)
} \
} while (0)
-#define test(version, mem, ipa, ipb, ipc, ipd) do { \
- bool _s = lookup(t.root##version, (version) == 4 ? 32 : 128, \
- ip##version(ipa, ipb, ipc, ipd)) == (mem); \
- maybe_fail(); \
+#define test(version, mem, ipa, ipb, ipc, ipd) do { \
+ bool _s = lookup(t->root##version, (version) == 4 ? 32 : 128, \
+ ip##version(ipa, ipb, ipc, ipd)) == (mem); \
+ maybe_fail(); \
} while (0)
-#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \
- bool _s = lookup(t.root##version, (version) == 4 ? 32 : 128, \
- ip##version(ipa, ipb, ipc, ipd)) != (mem); \
- maybe_fail(); \
+#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \
+ bool _s = lookup(t->root##version, (version) == 4 ? 32 : 128, \
+ ip##version(ipa, ipb, ipc, ipd)) != (mem); \
+ maybe_fail(); \
} while (0)
#define test_boolean(cond) do { \
@@ -503,7 +496,7 @@ bool __init wg_allowedips_selftest(void)
*g = init_peer(), *h = init_peer();
struct allowedips_node *iter_node;
bool success = false;
- struct allowedips t;
+ struct allowedips *t = kmalloc(sizeof(*t), GFP_KERNEL);
DEFINE_MUTEX(mutex);
struct in6_addr ip;
size_t i = 0, count = 0;
@@ -511,9 +504,9 @@ bool __init wg_allowedips_selftest(void)
mutex_init(&mutex);
mutex_lock(&mutex);
- wg_allowedips_init(&t);
+ wg_allowedips_init(t, &mutex);
- if (!a || !b || !c || !d || !e || !f || !g || !h) {
+ if (!a || !b || !c || !d || !e || !f || !g || !h || !t) {
pr_err("allowedips self-test malloc: FAIL\n");
goto free;
}
@@ -547,8 +540,8 @@ bool __init wg_allowedips_selftest(void)
insert(4, d, 10, 1, 0, 16, 29);
if (IS_ENABLED(DEBUG_PRINT_TRIE_GRAPHVIZ)) {
- print_tree(t.root4, 32);
- print_tree(t.root6, 128);
+ print_tree(t->root4, 32);
+ print_tree(t->root6, 128);
}
success = true;
@@ -587,18 +580,18 @@ bool __init wg_allowedips_selftest(void)
insert(4, a, 128, 0, 0, 0, 32);
insert(4, a, 192, 0, 0, 0, 32);
insert(4, a, 255, 0, 0, 0, 32);
- wg_allowedips_remove_by_peer(&t, a, &mutex);
+ wg_allowedips_remove_by_peer(t, a);
test_negative(4, a, 1, 0, 0, 0);
test_negative(4, a, 64, 0, 0, 0);
test_negative(4, a, 128, 0, 0, 0);
test_negative(4, a, 192, 0, 0, 0);
test_negative(4, a, 255, 0, 0, 0);
- wg_allowedips_free(&t, &mutex);
- wg_allowedips_init(&t);
+ wg_allowedips_free(t);
+ wg_allowedips_init(t, &mutex);
insert(4, a, 192, 168, 0, 0, 16);
insert(4, a, 192, 168, 0, 0, 24);
- wg_allowedips_remove_by_peer(&t, a, &mutex);
+ wg_allowedips_remove_by_peer(t, a);
test_negative(4, a, 192, 168, 0, 1);
/* These will hit the WARN_ON(len >= 128) in free_node if something
@@ -608,12 +601,12 @@ bool __init wg_allowedips_selftest(void)
part = cpu_to_be64(~(1LLU << (i % 64)));
memset(&ip, 0xff, 16);
memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);
- wg_allowedips_insert_v6(&t, &ip, 128, a, &mutex);
+ wg_allowedips_insert_v6(t, &ip, 128, a);
}
- wg_allowedips_free(&t, &mutex);
+ wg_allowedips_free(t);
- wg_allowedips_init(&t);
+ wg_allowedips_init(t, &mutex);
insert(4, a, 192, 95, 5, 93, 27);
insert(6, a, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
insert(4, a, 10, 1, 0, 20, 29);
@@ -661,7 +654,8 @@ bool __init wg_allowedips_selftest(void)
pr_info("allowedips self-tests: pass\n");
free:
- wg_allowedips_free(&t, &mutex);
+ wg_allowedips_free(t);
+ kfree(t);
kfree(a);
kfree(b);
kfree(c);