diff options
Diffstat (limited to 'drivers/net/wireguard/allowedips.c')
-rw-r--r-- | drivers/net/wireguard/allowedips.c | 87 |
1 files changed, 62 insertions, 25 deletions
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c index 3725e9c..c52782b 100644 --- a/drivers/net/wireguard/allowedips.c +++ b/drivers/net/wireguard/allowedips.c @@ -66,8 +66,7 @@ static void root_remove_peer_lists(struct allowedips_node *root) } } -static void walk_remove_by_peer(struct allowedips_node __rcu **top, - struct wg_peer *peer, struct mutex *lock) +static void walk_cleanup_empty_nodes(struct allowedips_node __rcu **top, struct mutex *lock) { #define REF(p) rcu_access_pointer(p) #define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock)) @@ -80,7 +79,7 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct allowedips_node *node, *prev; unsigned int len; - if (unlikely(!peer || !REF(*top))) + if (unlikely(!REF(*top))) return; for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) { @@ -100,16 +99,11 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, if (REF(node->bit[1])) PUSH(&node->bit[1]); } else { - if (rcu_dereference_protected(node->peer, - lockdep_is_held(lock)) == peer) { - RCU_INIT_POINTER(node->peer, NULL); - list_del_init(&node->peer_list); - if (!node->bit[0] || !node->bit[1]) { - rcu_assign_pointer(*nptr, DEREF( - &node->bit[!REF(node->bit[0])])); - kfree_rcu(node, rcu); - node = DEREF(nptr); - } + if (!rcu_dereference_protected(node->peer, lockdep_is_held(lock)) && + (!node->bit[0] || !node->bit[1])) { + rcu_assign_pointer(*nptr, DEREF(&node->bit[!REF(node->bit[0])])); + kfree_rcu(node, rcu); + node = DEREF(nptr); } --len; } @@ -281,29 +275,54 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, return 0; } -void wg_allowedips_init(struct allowedips *table) +static void wg_allowedips_cleanup(struct work_struct *work) +{ + struct delayed_work *dwork = to_delayed_work(work); + struct allowedips *table = container_of(dwork, struct allowedips, cleanup_work); + + if (!mutex_trylock(table->update_lock)) { + queue_delayed_work(system_power_efficient_wq, dwork, HZ / 4); + return; + } + if (table->dirty4) { + walk_cleanup_empty_nodes(&table->root4, table->update_lock); + table->dirty4 = false; + } + if (table->dirty6) { + walk_cleanup_empty_nodes(&table->root6, table->update_lock); + table->dirty6 = false; + } + mutex_unlock(table->update_lock); +} + +void wg_allowedips_init(struct allowedips *table, struct mutex *update_lock) { table->root4 = table->root6 = NULL; + table->update_lock = update_lock; table->seq = 1; + table->dirty4 = table->dirty6 = false; + INIT_DEFERRABLE_WORK(&table->cleanup_work, wg_allowedips_cleanup); } -void wg_allowedips_free(struct allowedips *table, struct mutex *lock) +void wg_allowedips_free(struct allowedips *table) { struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6; ++table->seq; + table->dirty4 = table->dirty6 = false; + cancel_delayed_work_sync(&table->cleanup_work); RCU_INIT_POINTER(table->root4, NULL); RCU_INIT_POINTER(table->root6, NULL); if (rcu_access_pointer(old4)) { struct allowedips_node *node = rcu_dereference_protected(old4, - lockdep_is_held(lock)); + lockdep_is_held(table->update_lock)); root_remove_peer_lists(node); call_rcu(&node->rcu, root_free_rcu); } if (rcu_access_pointer(old6)) { struct allowedips_node *node = rcu_dereference_protected(old6, - lockdep_is_held(lock)); + lockdep_is_held(table->update_lock)); root_remove_peer_lists(node); call_rcu(&node->rcu, root_free_rcu); @@ -311,33 +330,51 @@ void wg_allowedips_free(struct allowedips *table, struct mutex *lock) } int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, - u8 cidr, struct wg_peer *peer, struct mutex *lock) + u8 cidr, struct wg_peer *peer) { /* Aligned so it can be passed to fls */ u8 key[4] __aligned(__alignof(u32)); ++table->seq; + if (table->dirty4) { + walk_cleanup_empty_nodes(&table->root4, table->update_lock); + table->dirty4 = false; + } swap_endian(key, (const u8 *)ip, 32); - return add(&table->root4, 32, key, cidr, peer, lock); + return add(&table->root4, 32, key, cidr, peer, table->update_lock); } int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, - u8 cidr, struct wg_peer *peer, struct mutex *lock) + u8 cidr, struct wg_peer *peer) { /* Aligned so it can be passed to fls64 */ u8 key[16] __aligned(__alignof(u64)); ++table->seq; + if (table->dirty6) { + walk_cleanup_empty_nodes(&table->root6, table->update_lock); + table->dirty6 = false; + } swap_endian(key, (const u8 *)ip, 128); - return add(&table->root6, 128, key, cidr, peer, lock); + return add(&table->root6, 128, key, cidr, peer, table->update_lock); } -void wg_allowedips_remove_by_peer(struct allowedips *table, - struct wg_peer *peer, struct mutex *lock) +void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer) { + struct allowedips_node *node, *tmp; + + lockdep_assert_held(table->update_lock); ++table->seq; - walk_remove_by_peer(&table->root4, peer, lock); - walk_remove_by_peer(&table->root6, peer, lock); + + list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) { + list_del_init(&node->peer_list); + RCU_INIT_POINTER(node->peer, NULL); + table->dirty4 |= node->bitlen == 32; + table->dirty6 |= node->bitlen == 128; + } + + if (table->dirty4 || table->dirty6) + queue_delayed_work(system_power_efficient_wq, &table->cleanup_work, HZ); } int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) |