diff --git a/include/net/busy_poll.h b/include/net/busy_poll.h index 4202c609bb0b..7994455ec714 100644 --- a/include/net/busy_poll.h +++ b/include/net/busy_poll.h @@ -133,7 +133,7 @@ static inline void sk_mark_napi_id(struct sock *sk, const struct sk_buff *skb) if (unlikely(READ_ONCE(sk->sk_napi_id) != skb->napi_id)) WRITE_ONCE(sk->sk_napi_id, skb->napi_id); #endif - sk_rx_queue_set(sk, skb); + sk_rx_queue_update(sk, skb); } static inline void __sk_mark_napi_id_once(struct sock *sk, unsigned int napi_id) diff --git a/include/net/sock.h b/include/net/sock.h index 715cdb4b2b79..bea21ff70e74 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -1913,18 +1913,31 @@ static inline int sk_tx_queue_get(const struct sock *sk) return -1; } -static inline void sk_rx_queue_set(struct sock *sk, const struct sk_buff *skb) +static inline void __sk_rx_queue_set(struct sock *sk, + const struct sk_buff *skb, + bool force_set) { #ifdef CONFIG_SOCK_RX_QUEUE_MAPPING if (skb_rx_queue_recorded(skb)) { u16 rx_queue = skb_get_rx_queue(skb); - if (unlikely(READ_ONCE(sk->sk_rx_queue_mapping) != rx_queue)) + if (force_set || + unlikely(READ_ONCE(sk->sk_rx_queue_mapping) != rx_queue)) WRITE_ONCE(sk->sk_rx_queue_mapping, rx_queue); } #endif } +static inline void sk_rx_queue_set(struct sock *sk, const struct sk_buff *skb) +{ + __sk_rx_queue_set(sk, skb, true); +} + +static inline void sk_rx_queue_update(struct sock *sk, const struct sk_buff *skb) +{ + __sk_rx_queue_set(sk, skb, false); +} + static inline void sk_rx_queue_clear(struct sock *sk) { #ifdef CONFIG_SOCK_RX_QUEUE_MAPPING