diff options
-rw-r--r-- | drivers/net/wireguard/allowedips.c | 87 | ||||
-rw-r--r-- | drivers/net/wireguard/allowedips.h | 14 | ||||
-rw-r--r-- | drivers/net/wireguard/device.c | 2 | ||||
-rw-r--r-- | drivers/net/wireguard/netlink.c | 15 | ||||
-rw-r--r-- | drivers/net/wireguard/peer.c | 5 | ||||
-rw-r--r-- | drivers/net/wireguard/selftest/allowedips.c | 106 |
6 files changed, 128 insertions, 101 deletions
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c index 3725e9c..c52782b 100644 --- a/drivers/net/wireguard/allowedips.c +++ b/drivers/net/wireguard/allowedips.c @@ -66,8 +66,7 @@ static void root_remove_peer_lists(struct allowedips_node *root) } } -static void walk_remove_by_peer(struct allowedips_node __rcu **top, - struct wg_peer *peer, struct mutex *lock) +static void walk_cleanup_empty_nodes(struct allowedips_node __rcu **top, struct mutex *lock) { #define REF(p) rcu_access_pointer(p) #define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock)) @@ -80,7 +79,7 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct allowedips_node *node, *prev; unsigned int len; - if (unlikely(!peer || !REF(*top))) + if (unlikely(!REF(*top))) return; for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) { @@ -100,16 +99,11 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, if (REF(node->bit[1])) PUSH(&node->bit[1]); } else { - if (rcu_dereference_protected(node->peer, - lockdep_is_held(lock)) == peer) { - RCU_INIT_POINTER(node->peer, NULL); - list_del_init(&node->peer_list); - if (!node->bit[0] || !node->bit[1]) { - rcu_assign_pointer(*nptr, DEREF( - &node->bit[!REF(node->bit[0])])); - kfree_rcu(node, rcu); - node = DEREF(nptr); - } + if (!rcu_dereference_protected(node->peer, lockdep_is_held(lock)) && + (!node->bit[0] || !node->bit[1])) { + rcu_assign_pointer(*nptr, DEREF(&node->bit[!REF(node->bit[0])])); + kfree_rcu(node, rcu); + node = DEREF(nptr); } --len; } @@ -281,29 +275,54 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, return 0; } -void wg_allowedips_init(struct allowedips *table) +static void wg_allowedips_cleanup(struct work_struct *work) +{ + struct delayed_work *dwork = to_delayed_work(work); + struct allowedips *table = container_of(dwork, struct allowedips, cleanup_work); + + if (!mutex_trylock(table->update_lock)) { + queue_delayed_work(system_power_efficient_wq, dwork, HZ / 4); + return; + } + if (table->dirty4) { + walk_cleanup_empty_nodes(&table->root4, table->update_lock); + table->dirty4 = false; + } + if (table->dirty6) { + walk_cleanup_empty_nodes(&table->root6, table->update_lock); + table->dirty6 = false; + } + mutex_unlock(table->update_lock); +} + +void wg_allowedips_init(struct allowedips *table, struct mutex *update_lock) { table->root4 = table->root6 = NULL; + table->update_lock = update_lock; table->seq = 1; + table->dirty4 = table->dirty6 = false; + INIT_DEFERRABLE_WORK(&table->cleanup_work, wg_allowedips_cleanup); } -void wg_allowedips_free(struct allowedips *table, struct mutex *lock) +void wg_allowedips_free(struct allowedips *table) { struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6; ++table->seq; + table->dirty4 = table->dirty6 = false; + cancel_delayed_work_sync(&table->cleanup_work); RCU_INIT_POINTER(table->root4, NULL); RCU_INIT_POINTER(table->root6, NULL); if (rcu_access_pointer(old4)) { struct allowedips_node *node = rcu_dereference_protected(old4, - lockdep_is_held(lock)); + lockdep_is_held(table->update_lock)); root_remove_peer_lists(node); call_rcu(&node->rcu, root_free_rcu); } if (rcu_access_pointer(old6)) { struct allowedips_node *node = rcu_dereference_protected(old6, - lockdep_is_held(lock)); + lockdep_is_held(table->update_lock)); root_remove_peer_lists(node); call_rcu(&node->rcu, root_free_rcu); @@ -311,33 +330,51 @@ void wg_allowedips_free(struct allowedips *table, struct mutex *lock) } int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, - u8 cidr, struct wg_peer *peer, struct mutex *lock) + u8 cidr, struct wg_peer *peer) { /* Aligned so it can be passed to fls */ u8 key[4] __aligned(__alignof(u32)); ++table->seq; + if (table->dirty4) { + walk_cleanup_empty_nodes(&table->root4, table->update_lock); + table->dirty4 = false; + } swap_endian(key, (const u8 *)ip, 32); - return add(&table->root4, 32, key, cidr, peer, lock); + return add(&table->root4, 32, key, cidr, peer, table->update_lock); } int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, - u8 cidr, struct wg_peer *peer, struct mutex *lock) + u8 cidr, struct wg_peer *peer) { /* Aligned so it can be passed to fls64 */ u8 key[16] __aligned(__alignof(u64)); ++table->seq; + if (table->dirty6) { + walk_cleanup_empty_nodes(&table->root6, table->update_lock); + table->dirty6 = false; + } swap_endian(key, (const u8 *)ip, 128); - return add(&table->root6, 128, key, cidr, peer, lock); + return add(&table->root6, 128, key, cidr, peer, table->update_lock); } -void wg_allowedips_remove_by_peer(struct allowedips *table, - struct wg_peer *peer, struct mutex *lock) +void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer) { + struct allowedips_node *node, *tmp; + + lockdep_assert_held(table->update_lock); ++table->seq; - walk_remove_by_peer(&table->root4, peer, lock); - walk_remove_by_peer(&table->root6, peer, lock); + + list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) { + list_del_init(&node->peer_list); + RCU_INIT_POINTER(node->peer, NULL); + table->dirty4 |= node->bitlen == 32; + table->dirty6 |= node->bitlen == 128; + } + + if (table->dirty4 || table->dirty6) + queue_delayed_work(system_power_efficient_wq, &table->cleanup_work, HZ); } int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) diff --git a/drivers/net/wireguard/allowedips.h b/drivers/net/wireguard/allowedips.h index e5c83ca..602c334 100644 --- a/drivers/net/wireguard/allowedips.h +++ b/drivers/net/wireguard/allowedips.h @@ -32,17 +32,19 @@ struct allowedips_node { struct allowedips { struct allowedips_node __rcu *root4; struct allowedips_node __rcu *root6; + struct mutex *update_lock; + struct delayed_work cleanup_work; u64 seq; + bool dirty4, dirty6; }; -void wg_allowedips_init(struct allowedips *table); -void wg_allowedips_free(struct allowedips *table, struct mutex *mutex); +void wg_allowedips_init(struct allowedips *table, struct mutex *update_lock); +void wg_allowedips_free(struct allowedips *table); int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, - u8 cidr, struct wg_peer *peer, struct mutex *lock); + u8 cidr, struct wg_peer *peer); int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, - u8 cidr, struct wg_peer *peer, struct mutex *lock); -void wg_allowedips_remove_by_peer(struct allowedips *table, - struct wg_peer *peer, struct mutex *lock); + u8 cidr, struct wg_peer *peer); +void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer); /* The ip input pointer should be __aligned(__alignof(u64))) */ int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr); diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c index 551ddaa..610ce1d 100644 --- a/drivers/net/wireguard/device.c +++ b/drivers/net/wireguard/device.c @@ -299,7 +299,7 @@ static int wg_newlink(struct net *src_net, struct net_device *dev, mutex_init(&wg->socket_update_lock); mutex_init(&wg->device_update_lock); skb_queue_head_init(&wg->incoming_handshakes); - wg_allowedips_init(&wg->peer_allowedips); + wg_allowedips_init(&wg->peer_allowedips, &wg->device_update_lock); wg_cookie_checker_init(&wg->cookie_checker, wg); INIT_LIST_HEAD(&wg->peer_list); wg->device_update_gen = 1; diff --git a/drivers/net/wireguard/netlink.c b/drivers/net/wireguard/netlink.c index d0f3b6d..9dcbc57 100644 --- a/drivers/net/wireguard/netlink.c +++ b/drivers/net/wireguard/netlink.c @@ -340,16 +340,12 @@ static int set_allowedip(struct wg_peer *peer, struct nlattr **attrs) if (family == AF_INET && cidr <= 32 && nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) - ret = wg_allowedips_insert_v4( - &peer->device->peer_allowedips, - nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, - &peer->device->device_update_lock); + ret = wg_allowedips_insert_v4(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer); else if (family == AF_INET6 && cidr <= 128 && nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) - ret = wg_allowedips_insert_v6( - &peer->device->peer_allowedips, - nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, - &peer->device->device_update_lock); + ret = wg_allowedips_insert_v6(&peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer); return ret; } @@ -449,8 +445,7 @@ static int set_peer(struct wg_device *wg, struct nlattr **attrs) } if (flags & WGPEER_F_REPLACE_ALLOWEDIPS) - wg_allowedips_remove_by_peer(&wg->peer_allowedips, peer, - &wg->device_update_lock); + wg_allowedips_remove_by_peer(&wg->peer_allowedips, peer); if (attrs[WGPEER_A_ALLOWEDIPS]) { struct nlattr *attr, *allowedip[WGALLOWEDIP_A_MAX + 1]; diff --git a/drivers/net/wireguard/peer.c b/drivers/net/wireguard/peer.c index 3a042d2..3a14a37 100644 --- a/drivers/net/wireguard/peer.c +++ b/drivers/net/wireguard/peer.c @@ -81,8 +81,7 @@ static void peer_make_dead(struct wg_peer *peer) { /* Remove from configuration-time lookup structures. */ list_del_init(&peer->peer_list); - wg_allowedips_remove_by_peer(&peer->device->peer_allowedips, peer, - &peer->device->device_update_lock); + wg_allowedips_remove_by_peer(&peer->device->peer_allowedips, peer); wg_pubkey_hashtable_remove(peer->device->peer_hashtable, peer); /* Mark as dead, so that we don't allow jumping contexts after. */ @@ -172,7 +171,7 @@ void wg_peer_remove_all(struct wg_device *wg) lockdep_assert_held(&wg->device_update_lock); /* Avoid having to traverse individually for each one. */ - wg_allowedips_free(&wg->peer_allowedips, &wg->device_update_lock); + wg_allowedips_free(&wg->peer_allowedips); list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list) { peer_make_dead(peer); diff --git a/drivers/net/wireguard/selftest/allowedips.c b/drivers/net/wireguard/selftest/allowedips.c index 0d2a43a..61202d5 100644 --- a/drivers/net/wireguard/selftest/allowedips.c +++ b/drivers/net/wireguard/selftest/allowedips.c @@ -273,22 +273,21 @@ static __init bool randomized_test(void) { unsigned int i, j, k, mutate_amount, cidr; u8 ip[16], mutate_mask[16], mutated[16]; - struct wg_peer **peers, *peer; + struct wg_peer **peers = kcalloc(NUM_PEERS, sizeof(*peers), GFP_KERNEL), *peer; struct horrible_allowedips h; DEFINE_MUTEX(mutex); - struct allowedips t; + struct allowedips *t = kmalloc(sizeof(*t), GFP_KERNEL); bool ret = false; - mutex_init(&mutex); - - wg_allowedips_init(&t); - horrible_allowedips_init(&h); - - peers = kcalloc(NUM_PEERS, sizeof(*peers), GFP_KERNEL); - if (unlikely(!peers)) { + if (unlikely(!t || !peers)) { pr_err("allowedips random self-test malloc: FAIL\n"); goto free; } + + mutex_init(&mutex); + wg_allowedips_init(t, &mutex); + horrible_allowedips_init(&h); + for (i = 0; i < NUM_PEERS; ++i) { peers[i] = kzalloc(sizeof(*peers[i]), GFP_KERNEL); if (unlikely(!peers[i])) { @@ -305,13 +304,11 @@ static __init bool randomized_test(void) prandom_bytes(ip, 4); cidr = prandom_u32_max(32) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; - if (wg_allowedips_insert_v4(&t, (struct in_addr *)ip, cidr, - peer, &mutex) < 0) { + if (wg_allowedips_insert_v4(t, (struct in_addr *)ip, cidr, peer) < 0) { pr_err("allowedips random self-test malloc: FAIL\n"); goto free_locked; } - if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip, - cidr, peer) < 0) { + if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip, cidr, peer) < 0) { pr_err("allowedips random self-test malloc: FAIL\n"); goto free_locked; } @@ -331,14 +328,13 @@ static __init bool randomized_test(void) prandom_u32_max(256)); cidr = prandom_u32_max(32) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; - if (wg_allowedips_insert_v4(&t, - (struct in_addr *)mutated, - cidr, peer, &mutex) < 0) { + if (wg_allowedips_insert_v4(t, (struct in_addr *)mutated, + cidr, peer) < 0) { pr_err("allowedips random self-test malloc: FAIL\n"); goto free_locked; } - if (horrible_allowedips_insert_v4(&h, - (struct in_addr *)mutated, cidr, peer)) { + if (horrible_allowedips_insert_v4(&h, (struct in_addr *)mutated, + cidr, peer)) { pr_err("allowedips random self-test malloc: FAIL\n"); goto free_locked; } @@ -349,13 +345,11 @@ static __init bool randomized_test(void) prandom_bytes(ip, 16); cidr = prandom_u32_max(128) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; - if (wg_allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr, - peer, &mutex) < 0) { + if (wg_allowedips_insert_v6(t, (struct in6_addr *)ip, cidr, peer) < 0) { pr_err("allowedips random self-test malloc: FAIL\n"); goto free_locked; } - if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip, - cidr, peer) < 0) { + if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip, cidr, peer) < 0) { pr_err("allowedips random self-test malloc: FAIL\n"); goto free_locked; } @@ -375,15 +369,13 @@ static __init bool randomized_test(void) prandom_u32_max(256)); cidr = prandom_u32_max(128) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; - if (wg_allowedips_insert_v6(&t, - (struct in6_addr *)mutated, - cidr, peer, &mutex) < 0) { + if (wg_allowedips_insert_v6(t, (struct in6_addr *)mutated, + cidr, peer) < 0) { pr_err("allowedips random self-test malloc: FAIL\n"); goto free_locked; } - if (horrible_allowedips_insert_v6( - &h, (struct in6_addr *)mutated, cidr, - peer)) { + if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)mutated, + cidr, peer)) { pr_err("allowedips random self-test malloc: FAIL\n"); goto free_locked; } @@ -393,13 +385,13 @@ static __init bool randomized_test(void) mutex_unlock(&mutex); if (IS_ENABLED(DEBUG_PRINT_TRIE_GRAPHVIZ)) { - print_tree(t.root4, 32); - print_tree(t.root6, 128); + print_tree(t->root4, 32); + print_tree(t->root6, 128); } for (i = 0; i < NUM_QUERIES; ++i) { prandom_bytes(ip, 4); - if (lookup(t.root4, 32, ip) != + if (lookup(t->root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { pr_err("allowedips random self-test: FAIL\n"); goto free; @@ -408,7 +400,7 @@ static __init bool randomized_test(void) for (i = 0; i < NUM_QUERIES; ++i) { prandom_bytes(ip, 16); - if (lookup(t.root6, 128, ip) != + if (lookup(t->root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { pr_err("allowedips random self-test: FAIL\n"); goto free; @@ -419,7 +411,8 @@ static __init bool randomized_test(void) free: mutex_lock(&mutex); free_locked: - wg_allowedips_free(&t, &mutex); + wg_allowedips_free(t); + kfree(t); mutex_unlock(&mutex); horrible_allowedips_free(&h); if (peers) { @@ -466,8 +459,8 @@ static __init struct wg_peer *init_peer(void) } #define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \ - wg_allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \ - cidr, mem, &mutex) + wg_allowedips_insert_v##version(t, ip##version(ipa, ipb, ipc, ipd), \ + cidr, mem) #define maybe_fail() do { \ ++i; \ @@ -477,16 +470,16 @@ static __init struct wg_peer *init_peer(void) } \ } while (0) -#define test(version, mem, ipa, ipb, ipc, ipd) do { \ - bool _s = lookup(t.root##version, (version) == 4 ? 32 : 128, \ - ip##version(ipa, ipb, ipc, ipd)) == (mem); \ - maybe_fail(); \ +#define test(version, mem, ipa, ipb, ipc, ipd) do { \ + bool _s = lookup(t->root##version, (version) == 4 ? 32 : 128, \ + ip##version(ipa, ipb, ipc, ipd)) == (mem); \ + maybe_fail(); \ } while (0) -#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \ - bool _s = lookup(t.root##version, (version) == 4 ? 32 : 128, \ - ip##version(ipa, ipb, ipc, ipd)) != (mem); \ - maybe_fail(); \ +#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \ + bool _s = lookup(t->root##version, (version) == 4 ? 32 : 128, \ + ip##version(ipa, ipb, ipc, ipd)) != (mem); \ + maybe_fail(); \ } while (0) #define test_boolean(cond) do { \ @@ -503,7 +496,7 @@ bool __init wg_allowedips_selftest(void) *g = init_peer(), *h = init_peer(); struct allowedips_node *iter_node; bool success = false; - struct allowedips t; + struct allowedips *t = kmalloc(sizeof(*t), GFP_KERNEL); DEFINE_MUTEX(mutex); struct in6_addr ip; size_t i = 0, count = 0; @@ -511,9 +504,9 @@ bool __init wg_allowedips_selftest(void) mutex_init(&mutex); mutex_lock(&mutex); - wg_allowedips_init(&t); + wg_allowedips_init(t, &mutex); - if (!a || !b || !c || !d || !e || !f || !g || !h) { + if (!a || !b || !c || !d || !e || !f || !g || !h || !t) { pr_err("allowedips self-test malloc: FAIL\n"); goto free; } @@ -547,8 +540,8 @@ bool __init wg_allowedips_selftest(void) insert(4, d, 10, 1, 0, 16, 29); if (IS_ENABLED(DEBUG_PRINT_TRIE_GRAPHVIZ)) { - print_tree(t.root4, 32); - print_tree(t.root6, 128); + print_tree(t->root4, 32); + print_tree(t->root6, 128); } success = true; @@ -587,18 +580,18 @@ bool __init wg_allowedips_selftest(void) insert(4, a, 128, 0, 0, 0, 32); insert(4, a, 192, 0, 0, 0, 32); insert(4, a, 255, 0, 0, 0, 32); - wg_allowedips_remove_by_peer(&t, a, &mutex); + wg_allowedips_remove_by_peer(t, a); test_negative(4, a, 1, 0, 0, 0); test_negative(4, a, 64, 0, 0, 0); test_negative(4, a, 128, 0, 0, 0); test_negative(4, a, 192, 0, 0, 0); test_negative(4, a, 255, 0, 0, 0); - wg_allowedips_free(&t, &mutex); - wg_allowedips_init(&t); + wg_allowedips_free(t); + wg_allowedips_init(t, &mutex); insert(4, a, 192, 168, 0, 0, 16); insert(4, a, 192, 168, 0, 0, 24); - wg_allowedips_remove_by_peer(&t, a, &mutex); + wg_allowedips_remove_by_peer(t, a); test_negative(4, a, 192, 168, 0, 1); /* These will hit the WARN_ON(len >= 128) in free_node if something @@ -608,12 +601,12 @@ bool __init wg_allowedips_selftest(void) part = cpu_to_be64(~(1LLU << (i % 64))); memset(&ip, 0xff, 16); memcpy((u8 *)&ip + (i < 64) * 8, &part, 8); - wg_allowedips_insert_v6(&t, &ip, 128, a, &mutex); + wg_allowedips_insert_v6(t, &ip, 128, a); } - wg_allowedips_free(&t, &mutex); + wg_allowedips_free(t); - wg_allowedips_init(&t); + wg_allowedips_init(t, &mutex); insert(4, a, 192, 95, 5, 93, 27); insert(6, a, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128); insert(4, a, 10, 1, 0, 20, 29); @@ -661,7 +654,8 @@ bool __init wg_allowedips_selftest(void) pr_info("allowedips self-tests: pass\n"); free: - wg_allowedips_free(&t, &mutex); + wg_allowedips_free(t); + kfree(t); kfree(a); kfree(b); kfree(c); |