Merge branch 'net-packet-KCSAN'

Eric Dumazet says:

====================
net/packet: KCSAN awareness

This series is based on one syzbot report [1]

Seven 'flags/booleans' are converted to atomic bit variant.

po->xmit and po->tp_tstamp accesses get annotations.

[1]
BUG: KCSAN: data-race in packet_rcv / packet_setsockopt

read-write to 0xffff88813dbe84e4 of 1 bytes by task 12312 on cpu 0:
packet_setsockopt+0xb77/0xe60 net/packet/af_packet.c:3900
__sys_setsockopt+0x212/0x2b0 net/socket.c:2252
__do_sys_setsockopt net/socket.c:2263 [inline]
__se_sys_setsockopt net/socket.c:2260 [inline]
__x64_sys_setsockopt+0x62/0x70 net/socket.c:2260
do_syscall_x64 arch/x86/entry/common.c:50 [inline]
do_syscall_64+0x2b/0x70 arch/x86/entry/common.c:80
entry_SYSCALL_64_after_hwframe+0x63/0xcd

read to 0xffff88813dbe84e4 of 1 bytes by task 1911 on cpu 1:
packet_rcv+0x4b1/0xa40 net/packet/af_packet.c:2187
deliver_skb net/core/dev.c:2189 [inline]
dev_queue_xmit_nit+0x3a9/0x620 net/core/dev.c:2259
xmit_one+0x71/0x2a0 net/core/dev.c:3586
dev_hard_start_xmit+0x72/0x120 net/core/dev.c:3606
__dev_queue_xmit+0x91c/0x11c0 net/core/dev.c:4256
dev_queue_xmit include/linux/netdevice.h:3008 [inline]
neigh_hh_output include/net/neighbour.h:530 [inline]
neigh_output include/net/neighbour.h:544 [inline]
ip6_finish_output2+0x9e9/0xc30 net/ipv6/ip6_output.c:134
__ip6_finish_output net/ipv6/ip6_output.c:195 [inline]
ip6_finish_output+0x395/0x4f0 net/ipv6/ip6_output.c:206
NF_HOOK_COND include/linux/netfilter.h:291 [inline]
ip6_output+0x10e/0x210 net/ipv6/ip6_output.c:227
dst_output include/net/dst.h:445 [inline]
ip6_local_out+0x60/0x80 net/ipv6/output_core.c:161
ip6tunnel_xmit include/net/ip6_tunnel.h:161 [inline]
udp_tunnel6_xmit_skb+0x321/0x4a0 net/ipv6/ip6_udp_tunnel.c:109
send6+0x2ed/0x3b0 drivers/net/wireguard/socket.c:152
wg_socket_send_skb_to_peer+0xbb/0x120 drivers/net/wireguard/socket.c:178
wg_packet_create_data_done drivers/net/wireguard/send.c:251 [inline]
wg_packet_tx_worker+0x142/0x360 drivers/net/wireguard/send.c:276
process_one_work+0x3d3/0x720 kernel/workqueue.c:2289
worker_thread+0x618/0xa70 kernel/workqueue.c:2436
kthread+0x1a9/0x1e0 kernel/kthread.c:376
ret_from_fork+0x1f/0x30 arch/x86/entry/entry_64.S:306
====================

Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
This commit is contained in:
David S. Miller 2023-03-17 08:52:06 +00:00
commit 19a9fbc074
3 changed files with 87 additions and 63 deletions

View file

@ -307,7 +307,8 @@ static void packet_cached_dev_reset(struct packet_sock *po)
static bool packet_use_direct_xmit(const struct packet_sock *po)
{
return po->xmit == packet_direct_xmit;
/* Paired with WRITE_ONCE() in packet_setsockopt() */
return READ_ONCE(po->xmit) == packet_direct_xmit;
}
static u16 packet_pick_tx_queue(struct sk_buff *skb)
@ -339,14 +340,14 @@ static void __register_prot_hook(struct sock *sk)
{
struct packet_sock *po = pkt_sk(sk);
if (!po->running) {
if (!packet_sock_flag(po, PACKET_SOCK_RUNNING)) {
if (po->fanout)
__fanout_link(sk, po);
else
dev_add_pack(&po->prot_hook);
sock_hold(sk);
po->running = 1;
packet_sock_flag_set(po, PACKET_SOCK_RUNNING, 1);
}
}
@ -368,7 +369,7 @@ static void __unregister_prot_hook(struct sock *sk, bool sync)
lockdep_assert_held_once(&po->bind_lock);
po->running = 0;
packet_sock_flag_set(po, PACKET_SOCK_RUNNING, 0);
if (po->fanout)
__fanout_unlink(sk, po);
@ -388,7 +389,7 @@ static void unregister_prot_hook(struct sock *sk, bool sync)
{
struct packet_sock *po = pkt_sk(sk);
if (po->running)
if (packet_sock_flag(po, PACKET_SOCK_RUNNING))
__unregister_prot_hook(sk, sync);
}
@ -473,7 +474,7 @@ static __u32 __packet_set_timestamp(struct packet_sock *po, void *frame,
struct timespec64 ts;
__u32 ts_status;
if (!(ts_status = tpacket_get_timestamp(skb, &ts, po->tp_tstamp)))
if (!(ts_status = tpacket_get_timestamp(skb, &ts, READ_ONCE(po->tp_tstamp))))
return 0;
h.raw = frame;
@ -1306,22 +1307,23 @@ static int __packet_rcv_has_room(const struct packet_sock *po,
static int packet_rcv_has_room(struct packet_sock *po, struct sk_buff *skb)
{
int pressure, ret;
bool pressure;
int ret;
ret = __packet_rcv_has_room(po, skb);
pressure = ret != ROOM_NORMAL;
if (READ_ONCE(po->pressure) != pressure)
WRITE_ONCE(po->pressure, pressure);
if (packet_sock_flag(po, PACKET_SOCK_PRESSURE) != pressure)
packet_sock_flag_set(po, PACKET_SOCK_PRESSURE, pressure);
return ret;
}
static void packet_rcv_try_clear_pressure(struct packet_sock *po)
{
if (READ_ONCE(po->pressure) &&
if (packet_sock_flag(po, PACKET_SOCK_PRESSURE) &&
__packet_rcv_has_room(po, NULL) == ROOM_NORMAL)
WRITE_ONCE(po->pressure, 0);
packet_sock_flag_set(po, PACKET_SOCK_PRESSURE, false);
}
static void packet_sock_destruct(struct sock *sk)
@ -1408,7 +1410,8 @@ static unsigned int fanout_demux_rollover(struct packet_fanout *f,
i = j = min_t(int, po->rollover->sock, num - 1);
do {
po_next = pkt_sk(rcu_dereference(f->arr[i]));
if (po_next != po_skip && !READ_ONCE(po_next->pressure) &&
if (po_next != po_skip &&
!packet_sock_flag(po_next, PACKET_SOCK_PRESSURE) &&
packet_rcv_has_room(po_next, skb) == ROOM_NORMAL) {
if (i != j)
po->rollover->sock = i;
@ -1781,7 +1784,7 @@ static int fanout_add(struct sock *sk, struct fanout_args *args)
err = -EINVAL;
spin_lock(&po->bind_lock);
if (po->running &&
if (packet_sock_flag(po, PACKET_SOCK_RUNNING) &&
match->type == type &&
match->prot_hook.type == po->prot_hook.type &&
match->prot_hook.dev == po->prot_hook.dev) {
@ -2183,7 +2186,7 @@ static int packet_rcv(struct sk_buff *skb, struct net_device *dev,
sll = &PACKET_SKB_CB(skb)->sa.ll;
sll->sll_hatype = dev->type;
sll->sll_pkttype = skb->pkt_type;
if (unlikely(po->origdev))
if (unlikely(packet_sock_flag(po, PACKET_SOCK_ORIGDEV)))
sll->sll_ifindex = orig_dev->ifindex;
else
sll->sll_ifindex = dev->ifindex;
@ -2308,7 +2311,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
netoff = TPACKET_ALIGN(po->tp_hdrlen +
(maclen < 16 ? 16 : maclen)) +
po->tp_reserve;
if (po->has_vnet_hdr) {
if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
netoff += sizeof(struct virtio_net_hdr);
do_vnet = true;
}
@ -2402,7 +2405,8 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
* closer to the time of capture.
*/
ts_status = tpacket_get_timestamp(skb, &ts,
po->tp_tstamp | SOF_TIMESTAMPING_SOFTWARE);
READ_ONCE(po->tp_tstamp) |
SOF_TIMESTAMPING_SOFTWARE);
if (!ts_status)
ktime_get_real_ts64(&ts);
@ -2460,7 +2464,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
sll->sll_hatype = dev->type;
sll->sll_protocol = skb->protocol;
sll->sll_pkttype = skb->pkt_type;
if (unlikely(po->origdev))
if (unlikely(packet_sock_flag(po, PACKET_SOCK_ORIGDEV)))
sll->sll_ifindex = orig_dev->ifindex;
else
sll->sll_ifindex = dev->ifindex;
@ -2670,7 +2674,7 @@ static int tpacket_parse_header(struct packet_sock *po, void *frame,
return -EMSGSIZE;
}
if (unlikely(po->tp_tx_has_off)) {
if (unlikely(packet_sock_flag(po, PACKET_SOCK_TX_HAS_OFF))) {
int off_min, off_max;
off_min = po->tp_hdrlen - sizeof(struct sockaddr_ll);
@ -2778,7 +2782,8 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
size_max = po->tx_ring.frame_size
- (po->tp_hdrlen - sizeof(struct sockaddr_ll));
if ((size_max > dev->mtu + reserve + VLAN_HLEN) && !po->has_vnet_hdr)
if ((size_max > dev->mtu + reserve + VLAN_HLEN) &&
!packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR))
size_max = dev->mtu + reserve + VLAN_HLEN;
reinit_completion(&po->skb_completion);
@ -2807,7 +2812,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
status = TP_STATUS_SEND_REQUEST;
hlen = LL_RESERVED_SPACE(dev);
tlen = dev->needed_tailroom;
if (po->has_vnet_hdr) {
if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
vnet_hdr = data;
data += sizeof(*vnet_hdr);
tp_len -= sizeof(*vnet_hdr);
@ -2835,13 +2840,13 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
addr, hlen, copylen, &sockc);
if (likely(tp_len >= 0) &&
tp_len > dev->mtu + reserve &&
!po->has_vnet_hdr &&
!packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR) &&
!packet_extra_vlan_len_allowed(dev, skb))
tp_len = -EMSGSIZE;
if (unlikely(tp_len < 0)) {
tpacket_error:
if (po->tp_loss) {
if (packet_sock_flag(po, PACKET_SOCK_TP_LOSS)) {
__packet_set_status(po, ph,
TP_STATUS_AVAILABLE);
packet_increment_head(&po->tx_ring);
@ -2854,7 +2859,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
}
}
if (po->has_vnet_hdr) {
if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
if (virtio_net_hdr_to_skb(skb, vnet_hdr, vio_le())) {
tp_len = -EINVAL;
goto tpacket_error;
@ -2867,7 +2872,8 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
packet_inc_pending(&po->tx_ring);
status = TP_STATUS_SEND_REQUEST;
err = po->xmit(skb);
/* Paired with WRITE_ONCE() in packet_setsockopt() */
err = READ_ONCE(po->xmit)(skb);
if (unlikely(err != 0)) {
if (err > 0)
err = net_xmit_errno(err);
@ -2988,7 +2994,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
if (sock->type == SOCK_RAW)
reserve = dev->hard_header_len;
if (po->has_vnet_hdr) {
if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
err = packet_snd_vnet_parse(msg, &len, &vnet_hdr);
if (err)
goto out_unlock;
@ -3070,7 +3076,8 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
virtio_net_hdr_set_proto(skb, &vnet_hdr);
}
err = po->xmit(skb);
/* Paired with WRITE_ONCE() in packet_setsockopt() */
err = READ_ONCE(po->xmit)(skb);
if (unlikely(err != 0)) {
if (err > 0)
err = net_xmit_errno(err);
@ -3217,7 +3224,7 @@ static int packet_do_bind(struct sock *sk, const char *name, int ifindex,
if (need_rehook) {
dev_hold(dev);
if (po->running) {
if (packet_sock_flag(po, PACKET_SOCK_RUNNING)) {
rcu_read_unlock();
/* prevents packet_notifier() from calling
* register_prot_hook()
@ -3230,7 +3237,7 @@ static int packet_do_bind(struct sock *sk, const char *name, int ifindex,
dev->ifindex);
}
BUG_ON(po->running);
BUG_ON(packet_sock_flag(po, PACKET_SOCK_RUNNING));
WRITE_ONCE(po->num, proto);
po->prot_hook.type = proto;
@ -3447,7 +3454,7 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
packet_rcv_try_clear_pressure(pkt_sk(sk));
if (pkt_sk(sk)->has_vnet_hdr) {
if (packet_sock_flag(pkt_sk(sk), PACKET_SOCK_HAS_VNET_HDR)) {
err = packet_rcv_vnet(msg, skb, &len);
if (err)
goto out_free;
@ -3511,7 +3518,7 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
memcpy(msg->msg_name, &PACKET_SKB_CB(skb)->sa, copy_len);
}
if (pkt_sk(sk)->auxdata) {
if (packet_sock_flag(pkt_sk(sk), PACKET_SOCK_AUXDATA)) {
struct tpacket_auxdata aux;
aux.tp_status = TP_STATUS_USER;
@ -3882,7 +3889,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) {
ret = -EBUSY;
} else {
po->tp_loss = !!val;
packet_sock_flag_set(po, PACKET_SOCK_TP_LOSS, val);
ret = 0;
}
release_sock(sk);
@ -3897,9 +3904,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
if (copy_from_sockptr(&val, optval, sizeof(val)))
return -EFAULT;
lock_sock(sk);
po->auxdata = !!val;
release_sock(sk);
packet_sock_flag_set(po, PACKET_SOCK_AUXDATA, val);
return 0;
}
case PACKET_ORIGDEV:
@ -3911,9 +3916,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
if (copy_from_sockptr(&val, optval, sizeof(val)))
return -EFAULT;
lock_sock(sk);
po->origdev = !!val;
release_sock(sk);
packet_sock_flag_set(po, PACKET_SOCK_ORIGDEV, val);
return 0;
}
case PACKET_VNET_HDR:
@ -3931,7 +3934,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) {
ret = -EBUSY;
} else {
po->has_vnet_hdr = !!val;
packet_sock_flag_set(po, PACKET_SOCK_HAS_VNET_HDR, val);
ret = 0;
}
release_sock(sk);
@ -3946,7 +3949,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
if (copy_from_sockptr(&val, optval, sizeof(val)))
return -EFAULT;
po->tp_tstamp = val;
WRITE_ONCE(po->tp_tstamp, val);
return 0;
}
case PACKET_FANOUT:
@ -3993,7 +3996,7 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
lock_sock(sk);
if (!po->rx_ring.pg_vec && !po->tx_ring.pg_vec)
po->tp_tx_has_off = !!val;
packet_sock_flag_set(po, PACKET_SOCK_TX_HAS_OFF, val);
release_sock(sk);
return 0;
@ -4007,7 +4010,8 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
if (copy_from_sockptr(&val, optval, sizeof(val)))
return -EFAULT;
po->xmit = val ? packet_direct_xmit : dev_queue_xmit;
/* Paired with all lockless reads of po->xmit */
WRITE_ONCE(po->xmit, val ? packet_direct_xmit : dev_queue_xmit);
return 0;
}
default:
@ -4058,13 +4062,13 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
break;
case PACKET_AUXDATA:
val = po->auxdata;
val = packet_sock_flag(po, PACKET_SOCK_AUXDATA);
break;
case PACKET_ORIGDEV:
val = po->origdev;
val = packet_sock_flag(po, PACKET_SOCK_ORIGDEV);
break;
case PACKET_VNET_HDR:
val = po->has_vnet_hdr;
val = packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR);
break;
case PACKET_VERSION:
val = po->tp_version;
@ -4094,10 +4098,10 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
val = po->tp_reserve;
break;
case PACKET_LOSS:
val = po->tp_loss;
val = packet_sock_flag(po, PACKET_SOCK_TP_LOSS);
break;
case PACKET_TIMESTAMP:
val = po->tp_tstamp;
val = READ_ONCE(po->tp_tstamp);
break;
case PACKET_FANOUT:
val = (po->fanout ?
@ -4119,7 +4123,7 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
lv = sizeof(rstats);
break;
case PACKET_TX_HAS_OFF:
val = po->tp_tx_has_off;
val = packet_sock_flag(po, PACKET_SOCK_TX_HAS_OFF);
break;
case PACKET_QDISC_BYPASS:
val = packet_use_direct_xmit(po);
@ -4157,7 +4161,7 @@ static int packet_notifier(struct notifier_block *this,
case NETDEV_DOWN:
if (dev->ifindex == po->ifindex) {
spin_lock(&po->bind_lock);
if (po->running) {
if (packet_sock_flag(po, PACKET_SOCK_RUNNING)) {
__unregister_prot_hook(sk, false);
sk->sk_err = ENETDOWN;
if (!sock_flag(sk, SOCK_DEAD))
@ -4468,7 +4472,7 @@ static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u,
/* Detach socket from network */
spin_lock(&po->bind_lock);
was_running = po->running;
was_running = packet_sock_flag(po, PACKET_SOCK_RUNNING);
num = po->num;
if (was_running) {
WRITE_ONCE(po->num, 0);
@ -4679,7 +4683,7 @@ static int packet_seq_show(struct seq_file *seq, void *v)
s->sk_type,
ntohs(READ_ONCE(po->num)),
READ_ONCE(po->ifindex),
po->running,
packet_sock_flag(po, PACKET_SOCK_RUNNING),
atomic_read(&s->sk_rmem_alloc),
from_kuid_munged(seq_user_ns(seq), sock_i_uid(s)),
sock_i_ino(s));

View file

@ -18,18 +18,18 @@ static int pdiag_put_info(const struct packet_sock *po, struct sk_buff *nlskb)
pinfo.pdi_version = po->tp_version;
pinfo.pdi_reserve = po->tp_reserve;
pinfo.pdi_copy_thresh = po->copy_thresh;
pinfo.pdi_tstamp = po->tp_tstamp;
pinfo.pdi_tstamp = READ_ONCE(po->tp_tstamp);
pinfo.pdi_flags = 0;
if (po->running)
if (packet_sock_flag(po, PACKET_SOCK_RUNNING))
pinfo.pdi_flags |= PDI_RUNNING;
if (po->auxdata)
if (packet_sock_flag(po, PACKET_SOCK_AUXDATA))
pinfo.pdi_flags |= PDI_AUXDATA;
if (po->origdev)
if (packet_sock_flag(po, PACKET_SOCK_ORIGDEV))
pinfo.pdi_flags |= PDI_ORIGDEV;
if (po->has_vnet_hdr)
if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR))
pinfo.pdi_flags |= PDI_VNETHDR;
if (po->tp_loss)
if (packet_sock_flag(po, PACKET_SOCK_TP_LOSS))
pinfo.pdi_flags |= PDI_LOSS;
return nla_put(nlskb, PACKET_DIAG_INFO, sizeof(pinfo), &pinfo);

View file

@ -116,13 +116,7 @@ struct packet_sock {
int copy_thresh;
spinlock_t bind_lock;
struct mutex pg_vec_lock;
unsigned int running; /* bind_lock must be held */
unsigned int auxdata:1, /* writer must hold sock lock */
origdev:1,
has_vnet_hdr:1,
tp_loss:1,
tp_tx_has_off:1;
int pressure;
unsigned long flags;
int ifindex; /* bound device */
__be16 num;
struct packet_rollover *rollover;
@ -144,4 +138,30 @@ static inline struct packet_sock *pkt_sk(struct sock *sk)
return (struct packet_sock *)sk;
}
enum packet_sock_flags {
PACKET_SOCK_ORIGDEV,
PACKET_SOCK_AUXDATA,
PACKET_SOCK_TX_HAS_OFF,
PACKET_SOCK_TP_LOSS,
PACKET_SOCK_HAS_VNET_HDR,
PACKET_SOCK_RUNNING,
PACKET_SOCK_PRESSURE,
};
static inline void packet_sock_flag_set(struct packet_sock *po,
enum packet_sock_flags flag,
bool val)
{
if (val)
set_bit(flag, &po->flags);
else
clear_bit(flag, &po->flags);
}
static inline bool packet_sock_flag(const struct packet_sock *po,
enum packet_sock_flags flag)
{
return test_bit(flag, &po->flags);
}
#endif