summaryrefslogtreecommitdiff
path: root/drivers/net/wireguard/allowedips.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/net/wireguard/allowedips.c')
-rw-r--r-- drivers/net/wireguard/allowedips.c | 87
1 file changed, 62 insertions, 25 deletions
diff --git a/drivers/net/wireguard/allowedips.c b/drivers/net/wireguard/allowedips.c
index 3725e9c..c52782b 100644
--- a/drivers/net/wireguard/allowedips.c
+++ b/drivers/net/wireguard/allowedips.c
@@ -66,8 +66,7 @@ static void root_remove_peer_lists(struct allowedips_node *root)
}
}
-static void walk_remove_by_peer(struct allowedips_node __rcu **top,
- struct wg_peer *peer, struct mutex *lock)
+static void walk_cleanup_empty_nodes(struct allowedips_node __rcu **top, struct mutex *lock)
{
#define REF(p) rcu_access_pointer(p)
#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
@@ -80,7 +79,7 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top,
struct allowedips_node *node, *prev;
unsigned int len;
- if (unlikely(!peer || !REF(*top)))
+ if (unlikely(!REF(*top)))
return;
for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
@@ -100,16 +99,11 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top,
if (REF(node->bit[1]))
PUSH(&node->bit[1]);
} else {
- if (rcu_dereference_protected(node->peer,
- lockdep_is_held(lock)) == peer) {
- RCU_INIT_POINTER(node->peer, NULL);
- list_del_init(&node->peer_list);
- if (!node->bit[0] || !node->bit[1]) {
- rcu_assign_pointer(*nptr, DEREF(
- &node->bit[!REF(node->bit[0])]));
- kfree_rcu(node, rcu);
- node = DEREF(nptr);
- }
+ if (!rcu_dereference_protected(node->peer, lockdep_is_held(lock)) &&
+ (!node->bit[0] || !node->bit[1])) {
+ rcu_assign_pointer(*nptr, DEREF(&node->bit[!REF(node->bit[0])]));
+ kfree_rcu(node, rcu);
+ node = DEREF(nptr);
}
--len;
}
@@ -281,29 +275,54 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
return 0;
}
-void wg_allowedips_init(struct allowedips *table)
+static void wg_allowedips_cleanup(struct work_struct *work)
+{
+ struct delayed_work *dwork = to_delayed_work(work);
+ struct allowedips *table = container_of(dwork, struct allowedips, cleanup_work);
+
+ if (!mutex_trylock(table->update_lock)) {
+ queue_delayed_work(system_power_efficient_wq, dwork, HZ / 4);
+ return;
+ }
+ if (table->dirty4) {
+ walk_cleanup_empty_nodes(&table->root4, table->update_lock);
+ table->dirty4 = false;
+ }
+ if (table->dirty6) {
+ walk_cleanup_empty_nodes(&table->root6, table->update_lock);
+ table->dirty6 = false;
+ }
+ mutex_unlock(table->update_lock);
+}
+
+void wg_allowedips_init(struct allowedips *table, struct mutex *update_lock)
{
table->root4 = table->root6 = NULL;
+ table->update_lock = update_lock;
table->seq = 1;
+ table->dirty4 = table->dirty6 = false;
+ INIT_DEFERRABLE_WORK(&table->cleanup_work, wg_allowedips_cleanup);
}
-void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
+void wg_allowedips_free(struct allowedips *table)
{
struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6;
++table->seq;
+ table->dirty4 = table->dirty6 = false;
+ cancel_delayed_work_sync(&table->cleanup_work);
RCU_INIT_POINTER(table->root4, NULL);
RCU_INIT_POINTER(table->root6, NULL);
if (rcu_access_pointer(old4)) {
struct allowedips_node *node = rcu_dereference_protected(old4,
- lockdep_is_held(lock));
+ lockdep_is_held(table->update_lock));
root_remove_peer_lists(node);
call_rcu(&node->rcu, root_free_rcu);
}
if (rcu_access_pointer(old6)) {
struct allowedips_node *node = rcu_dereference_protected(old6,
- lockdep_is_held(lock));
+ lockdep_is_held(table->update_lock));
root_remove_peer_lists(node);
call_rcu(&node->rcu, root_free_rcu);
@@ -311,33 +330,51 @@ void wg_allowedips_free(struct allowedips *table, struct mutex *lock)
}
int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
- u8 cidr, struct wg_peer *peer, struct mutex *lock)
+ u8 cidr, struct wg_peer *peer)
{
/* Aligned so it can be passed to fls */
u8 key[4] __aligned(__alignof(u32));
++table->seq;
+ if (table->dirty4) {
+ walk_cleanup_empty_nodes(&table->root4, table->update_lock);
+ table->dirty4 = false;
+ }
swap_endian(key, (const u8 *)ip, 32);
- return add(&table->root4, 32, key, cidr, peer, lock);
+ return add(&table->root4, 32, key, cidr, peer, table->update_lock);
}
int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
- u8 cidr, struct wg_peer *peer, struct mutex *lock)
+ u8 cidr, struct wg_peer *peer)
{
/* Aligned so it can be passed to fls64 */
u8 key[16] __aligned(__alignof(u64));
++table->seq;
+ if (table->dirty6) {
+ walk_cleanup_empty_nodes(&table->root6, table->update_lock);
+ table->dirty6 = false;
+ }
swap_endian(key, (const u8 *)ip, 128);
- return add(&table->root6, 128, key, cidr, peer, lock);
+ return add(&table->root6, 128, key, cidr, peer, table->update_lock);
}
-void wg_allowedips_remove_by_peer(struct allowedips *table,
- struct wg_peer *peer, struct mutex *lock)
+void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer)
{
+ struct allowedips_node *node, *tmp;
+
+ lockdep_assert_held(table->update_lock);
++table->seq;
- walk_remove_by_peer(&table->root4, peer, lock);
- walk_remove_by_peer(&table->root6, peer, lock);
+
+ list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
+ list_del_init(&node->peer_list);
+ RCU_INIT_POINTER(node->peer, NULL);
+ table->dirty4 |= node->bitlen == 32;
+ table->dirty6 |= node->bitlen == 128;
+ }
+
+ if (table->dirty4 || table->dirty6)
+ queue_delayed_work(system_power_efficient_wq, &table->cleanup_work, HZ);
}
int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)