diff options
Diffstat (limited to 'drivers/net/wireguard/socket.c')
-rw-r--r-- | drivers/net/wireguard/socket.c | 37 |
1 files changed, 27 insertions, 10 deletions
diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c index f901802..62716cc 100644 --- a/drivers/net/wireguard/socket.c +++ b/drivers/net/wireguard/socket.c @@ -18,7 +18,8 @@ #include <net/ipv6.h> static int send4(struct wg_device *wg, struct sk_buff *skb, - struct endpoint *endpoint, u8 ds, struct dst_cache *cache) + struct endpoint *endpoint, u8 ds, struct dst_cache *cache, + rwlock_t *endpoint_lock) { struct flowi4 fl = { .saddr = endpoint->src4.s_addr, @@ -37,6 +38,9 @@ static int send4(struct wg_device *wg, struct sk_buff *skb, rcu_read_lock_bh(); sock = rcu_dereference_bh(wg->sock4); + if (likely(sock)) + sock_hold(sock); + rcu_read_unlock_bh(); if (unlikely(!sock)) { ret = -ENONET; @@ -80,6 +84,8 @@ static int send4(struct wg_device *wg, struct sk_buff *skb, if (cache) dst_cache_set_ip4(cache, &rt->dst, fl.saddr); } + if (endpoint_lock) + read_unlock_bh(endpoint_lock); skb->ignore_df = 1; udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds, @@ -88,14 +94,18 @@ static int send4(struct wg_device *wg, struct sk_buff *skb, goto out; err: + if (endpoint_lock) + read_unlock_bh(endpoint_lock); kfree_skb(skb); out: - rcu_read_unlock_bh(); + if (likely(sock)) + sock_put(sock); return ret; } static int send6(struct wg_device *wg, struct sk_buff *skb, - struct endpoint *endpoint, u8 ds, struct dst_cache *cache) + struct endpoint *endpoint, u8 ds, struct dst_cache *cache, + rwlock_t *endpoint_lock) { #if IS_ENABLED(CONFIG_IPV6) struct flowi6 fl = { @@ -117,6 +127,9 @@ static int send6(struct wg_device *wg, struct sk_buff *skb, rcu_read_lock_bh(); sock = rcu_dereference_bh(wg->sock6); + if (likely(sock)) + sock_hold(sock); + rcu_read_unlock_bh(); if (unlikely(!sock)) { ret = -ENONET; @@ -147,6 +160,8 @@ static int send6(struct wg_device *wg, struct sk_buff *skb, if (cache) dst_cache_set_ip6(cache, dst, &fl.saddr); } + if (endpoint_lock) + read_unlock_bh(endpoint_lock); skb->ignore_df = 1; udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds, @@ -155,9 +170,12 @@ static int send6(struct wg_device *wg, struct sk_buff *skb, goto out; err: + if (endpoint_lock) + read_unlock_bh(endpoint_lock); kfree_skb(skb); out: - rcu_read_unlock_bh(); + if (likely(sock)) + sock_put(sock); return ret; #else return -EAFNOSUPPORT; @@ -169,18 +187,17 @@ int wg_socket_send_skb_to_peer(struct wg_peer *peer, struct sk_buff *skb, u8 ds) size_t skb_len = skb->len; int ret = -EAFNOSUPPORT; - read_lock_bh(&peer->endpoint_lock); + read_lock_bh(&peer->endpoint_lock); /* Unlocked by send4/send6 */ if (peer->endpoint.addr.sa_family == AF_INET) ret = send4(peer->device, skb, &peer->endpoint, ds, - &peer->endpoint_cache); + &peer->endpoint_cache, &peer->endpoint_lock); else if (peer->endpoint.addr.sa_family == AF_INET6) ret = send6(peer->device, skb, &peer->endpoint, ds, - &peer->endpoint_cache); + &peer->endpoint_cache, &peer->endpoint_lock); else dev_kfree_skb(skb); if (likely(!ret)) peer->tx_bytes += skb_len; - read_unlock_bh(&peer->endpoint_lock); return ret; } @@ -221,9 +238,9 @@ int wg_socket_send_buffer_as_reply_to_skb(struct wg_device *wg, skb_put_data(skb, buffer, len); if (endpoint.addr.sa_family == AF_INET) - ret = send4(wg, skb, &endpoint, 0, NULL); + ret = send4(wg, skb, &endpoint, 0, NULL, NULL); else if (endpoint.addr.sa_family == AF_INET6) - ret = send6(wg, skb, &endpoint, 0, NULL); + ret = send6(wg, skb, &endpoint, 0, NULL, NULL); /* No other possibilities if the endpoint is valid, which it is, * as we checked above. */ |