diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index 6f6b65d3eed1..f4206001e2fe 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -111,7 +111,7 @@ static int __mptcp_socket_create(struct mptcp_sock *msk) if (err) return err; - msk->first = ssock->sk; + WRITE_ONCE(msk->first, ssock->sk); WRITE_ONCE(msk->subflow, ssock); subflow = mptcp_subflow_ctx(ssock->sk); list_add(&subflow->node, &msk->conn_list); @@ -2405,7 +2405,7 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk, sock_put(ssk); if (ssk == msk->first) - msk->first = NULL; + WRITE_ONCE(msk->first, NULL); out: if (ssk == msk->last_snd) @@ -2706,7 +2706,7 @@ static int __mptcp_init_sock(struct sock *sk) WRITE_ONCE(msk->rmem_released, 0); msk->timer_ival = TCP_RTO_MIN; - msk->first = NULL; + WRITE_ONCE(msk->first, NULL); inet_csk(sk)->icsk_sync_mss = mptcp_sync_mss; WRITE_ONCE(msk->csum_enabled, mptcp_is_checksum_enabled(sock_net(sk))); WRITE_ONCE(msk->allow_infinite_fallback, true);