diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2020-05-04 16:09:47 -0600 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2020-05-04 18:04:01 -0600 |
commit | 3795532d0efea055166dcc161ebd0c3cd4257da9 (patch) | |
tree | b492f6d924acf6c3ca992ea12e1dd074b6fefeca /drivers/net | |
parent | 33d6e27f7e95b5a46080bea803090be601ebf900 (diff) | |
download | wireguard-linux-trimmed-3795532d0efea055166dcc161ebd0c3cd4257da9.tar.gz wireguard-linux-trimmed-3795532d0efea055166dcc161ebd0c3cd4257da9.zip |
wireguard: socket: do not hold locks while transmitting packets (branch: jd/shorter-socket-lock)
Before, we followed this pattern for using the udp_tunnel api:
rcu_read_lock_bh();
sock = rcu_dereference(obj->sock);
...
udp_tunnel_xmit_skb(..., sock, ...);
rcu_read_unlock_bh();
This commit changes that to use a reference counter instead:
rcu_read_lock_bh();
sock = rcu_dereference(obj->sock);
sock_hold(sock);
rcu_read_unlock_bh();
...
udp_tunnel_xmit_skb(..., sock, ...);
sock_put(sock);
The advantage of the latter approach is that we now no longer hold any
locks while udp_tunnel_xmit_skb runs, since it could be somewhat slow on
systems with advanced qdisc or netfilter configurations. This should
avoid potential RCU stalls in those situations.
This commit makes sure we're holding neither the rcu read lock nor the
endpoint read lock when udp_tunnel_xmit_skb is called.
Fixes: a8f1bc7bdea3 ("net: WireGuard secure network tunnel")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'drivers/net')
-rw-r--r-- | drivers/net/wireguard/socket.c | 37 |
1 file changed, 27 insertions(+), 10 deletions(-)
diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c
index f901802..62716cc 100644
--- a/drivers/net/wireguard/socket.c
+++ b/drivers/net/wireguard/socket.c
@@ -18,7 +18,8 @@
 #include <net/ipv6.h>
 
 static int send4(struct wg_device *wg, struct sk_buff *skb,
-		 struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
+		 struct endpoint *endpoint, u8 ds, struct dst_cache *cache,
+		 rwlock_t *endpoint_lock)
 {
 	struct flowi4 fl = {
 		.saddr = endpoint->src4.s_addr,
@@ -37,6 +38,9 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
 
 	rcu_read_lock_bh();
 	sock = rcu_dereference_bh(wg->sock4);
+	if (likely(sock))
+		sock_hold(sock);
+	rcu_read_unlock_bh();
 
 	if (unlikely(!sock)) {
 		ret = -ENONET;
@@ -80,6 +84,8 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
 		if (cache)
 			dst_cache_set_ip4(cache, &rt->dst, fl.saddr);
 	}
+	if (endpoint_lock)
+		read_unlock_bh(endpoint_lock);
 	skb->ignore_df = 1;
 	udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds,
@@ -88,14 +94,18 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
 	goto out;
 
 err:
+	if (endpoint_lock)
+		read_unlock_bh(endpoint_lock);
 	kfree_skb(skb);
 out:
-	rcu_read_unlock_bh();
+	if (likely(sock))
+		sock_put(sock);
 	return ret;
 }
 
 static int send6(struct wg_device *wg, struct sk_buff *skb,
-		 struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
+		 struct endpoint *endpoint, u8 ds, struct dst_cache *cache,
+		 rwlock_t *endpoint_lock)
 {
 #if IS_ENABLED(CONFIG_IPV6)
 	struct flowi6 fl = {
@@ -117,6 +127,9 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
 
 	rcu_read_lock_bh();
 	sock = rcu_dereference_bh(wg->sock6);
+	if (likely(sock))
+		sock_hold(sock);
+	rcu_read_unlock_bh();
 
 	if (unlikely(!sock)) {
 		ret = -ENONET;
@@ -147,6 +160,8 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
 		if (cache)
 			dst_cache_set_ip6(cache, dst, &fl.saddr);
 	}
+	if (endpoint_lock)
+		read_unlock_bh(endpoint_lock);
 	skb->ignore_df = 1;
 	udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev,
 			     &fl.saddr, &fl.daddr, ds,
@@ -155,9 +170,12 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
 	goto out;
 
 err:
+	if (endpoint_lock)
+		read_unlock_bh(endpoint_lock);
 	kfree_skb(skb);
 out:
-	rcu_read_unlock_bh();
+	if (likely(sock))
+		sock_put(sock);
 	return ret;
 #else
 	return -EAFNOSUPPORT;
@@ -169,18 +187,17 @@ int wg_socket_send_skb_to_peer(struct wg_peer *peer, struct sk_buff *skb, u8 ds)
 	size_t skb_len = skb->len;
 	int ret = -EAFNOSUPPORT;
 
-	read_lock_bh(&peer->endpoint_lock);
+	read_lock_bh(&peer->endpoint_lock); /* Unlocked by send4/send6 */
 	if (peer->endpoint.addr.sa_family == AF_INET)
 		ret = send4(peer->device, skb, &peer->endpoint, ds,
-			    &peer->endpoint_cache);
+			    &peer->endpoint_cache, &peer->endpoint_lock);
 	else if (peer->endpoint.addr.sa_family == AF_INET6)
 		ret = send6(peer->device, skb, &peer->endpoint, ds,
-			    &peer->endpoint_cache);
+			    &peer->endpoint_cache, &peer->endpoint_lock);
 	else
 		dev_kfree_skb(skb);
 	if (likely(!ret))
 		peer->tx_bytes += skb_len;
-	read_unlock_bh(&peer->endpoint_lock);
 
 	return ret;
 }
@@ -221,9 +238,9 @@ int wg_socket_send_buffer_as_reply_to_skb(struct wg_device *wg,
 	skb_put_data(skb, buffer, len);
 
 	if (endpoint.addr.sa_family == AF_INET)
-		ret = send4(wg, skb, &endpoint, 0, NULL);
+		ret = send4(wg, skb, &endpoint, 0, NULL, NULL);
 	else if (endpoint.addr.sa_family == AF_INET6)
-		ret = send6(wg, skb, &endpoint, 0, NULL);
+		ret = send6(wg, skb, &endpoint, 0, NULL, NULL);
 	/* No other possibilities if the endpoint is valid, which it is,
 	 * as we checked above.
 	 */