tcp: allow again tcp_disconnect() when threads are waiting

As reported by Tom, .NET and applications built on top of it rely
on connect(AF_UNSPEC) to asynchronously cancel pending I/O operations
on a TCP socket.

The blamed commit below caused a regression, as such cancellation
can now fail.

As suggested by Eric, this change addresses the problem by explicitly
causing blocking I/O operations to terminate immediately (with an error)
when a concurrent disconnect() is executed.

Instead of tracking the number of threads blocked on a given socket,
track the number of disconnect() calls issued on that socket. If the
counter changes after a blocking operation releases and re-acquires the
socket lock, error out the current operation.

Fixes: 4faeee0cf8 ("tcp: deny tcp_disconnect() when threads are waiting")
Reported-by: Tom Deseyn <tdeseyn@redhat.com>
Closes: https://bugzilla.redhat.com/show_bug.cgi?id=1886305
Suggested-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Link: https://lore.kernel.org/r/f3b95e47e3dbed840960548aebaa8d954372db41.1697008693.git.pabeni@redhat.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
Paolo Abeni 2023-10-11 09:20:55 +02:00 committed by Jakub Kicinski
parent 242e34500a
commit 419ce133ab
10 changed files with 80 additions and 45 deletions

View File

@ -911,7 +911,7 @@ static int csk_wait_memory(struct chtls_dev *cdev,
struct sock *sk, long *timeo_p) struct sock *sk, long *timeo_p)
{ {
DEFINE_WAIT_FUNC(wait, woken_wake_function); DEFINE_WAIT_FUNC(wait, woken_wake_function);
int err = 0; int ret, err = 0;
long current_timeo; long current_timeo;
long vm_wait = 0; long vm_wait = 0;
bool noblock; bool noblock;
@ -942,10 +942,13 @@ static int csk_wait_memory(struct chtls_dev *cdev,
set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
sk->sk_write_pending++; sk->sk_write_pending++;
sk_wait_event(sk, &current_timeo, sk->sk_err || ret = sk_wait_event(sk, &current_timeo, sk->sk_err ||
(sk->sk_shutdown & SEND_SHUTDOWN) || (sk->sk_shutdown & SEND_SHUTDOWN) ||
(csk_mem_free(cdev, sk) && !vm_wait), &wait); (csk_mem_free(cdev, sk) && !vm_wait),
&wait);
sk->sk_write_pending--; sk->sk_write_pending--;
if (ret < 0)
goto do_error;
if (vm_wait) { if (vm_wait) {
vm_wait -= current_timeo; vm_wait -= current_timeo;
@ -1348,6 +1351,7 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int copied = 0; int copied = 0;
int target; int target;
long timeo; long timeo;
int ret;
buffers_freed = 0; buffers_freed = 0;
@ -1423,7 +1427,11 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
if (copied >= target) if (copied >= target)
break; break;
chtls_cleanup_rbuf(sk, copied); chtls_cleanup_rbuf(sk, copied);
sk_wait_data(sk, &timeo, NULL); ret = sk_wait_data(sk, &timeo, NULL);
if (ret < 0) {
copied = copied ? : ret;
goto unlock;
}
continue; continue;
found_ok_skb: found_ok_skb:
if (!skb->len) { if (!skb->len) {
@ -1518,6 +1526,8 @@ skip_copy:
if (buffers_freed) if (buffers_freed)
chtls_cleanup_rbuf(sk, copied); chtls_cleanup_rbuf(sk, copied);
unlock:
release_sock(sk); release_sock(sk);
return copied; return copied;
} }
@ -1534,6 +1544,7 @@ static int peekmsg(struct sock *sk, struct msghdr *msg,
int copied = 0; int copied = 0;
size_t avail; /* amount of available data in current skb */ size_t avail; /* amount of available data in current skb */
long timeo; long timeo;
int ret;
lock_sock(sk); lock_sock(sk);
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
@ -1585,7 +1596,12 @@ static int peekmsg(struct sock *sk, struct msghdr *msg,
release_sock(sk); release_sock(sk);
lock_sock(sk); lock_sock(sk);
} else { } else {
sk_wait_data(sk, &timeo, NULL); ret = sk_wait_data(sk, &timeo, NULL);
if (ret < 0) {
/* here 'copied' is 0 due to previous checks */
copied = ret;
break;
}
} }
if (unlikely(peek_seq != tp->copied_seq)) { if (unlikely(peek_seq != tp->copied_seq)) {
@ -1656,6 +1672,7 @@ int chtls_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int copied = 0; int copied = 0;
long timeo; long timeo;
int target; /* Read at least this many bytes */ int target; /* Read at least this many bytes */
int ret;
buffers_freed = 0; buffers_freed = 0;
@ -1747,7 +1764,11 @@ int chtls_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
if (copied >= target) if (copied >= target)
break; break;
chtls_cleanup_rbuf(sk, copied); chtls_cleanup_rbuf(sk, copied);
sk_wait_data(sk, &timeo, NULL); ret = sk_wait_data(sk, &timeo, NULL);
if (ret < 0) {
copied = copied ? : ret;
goto unlock;
}
continue; continue;
found_ok_skb: found_ok_skb:
@ -1816,6 +1837,7 @@ skip_copy:
if (buffers_freed) if (buffers_freed)
chtls_cleanup_rbuf(sk, copied); chtls_cleanup_rbuf(sk, copied);
unlock:
release_sock(sk); release_sock(sk);
return copied; return copied;
} }

View File

@ -336,7 +336,7 @@ struct sk_filter;
* @sk_cgrp_data: cgroup data for this cgroup * @sk_cgrp_data: cgroup data for this cgroup
* @sk_memcg: this socket's memory cgroup association * @sk_memcg: this socket's memory cgroup association
* @sk_write_pending: a write to stream socket waits to start * @sk_write_pending: a write to stream socket waits to start
* @sk_wait_pending: number of threads blocked on this socket * @sk_disconnects: number of disconnect operations performed on this sock
* @sk_state_change: callback to indicate change in the state of the sock * @sk_state_change: callback to indicate change in the state of the sock
* @sk_data_ready: callback to indicate there is data to be processed * @sk_data_ready: callback to indicate there is data to be processed
* @sk_write_space: callback to indicate there is bf sending space available * @sk_write_space: callback to indicate there is bf sending space available
@ -429,7 +429,7 @@ struct sock {
unsigned int sk_napi_id; unsigned int sk_napi_id;
#endif #endif
int sk_rcvbuf; int sk_rcvbuf;
int sk_wait_pending; int sk_disconnects;
struct sk_filter __rcu *sk_filter; struct sk_filter __rcu *sk_filter;
union { union {
@ -1189,8 +1189,7 @@ static inline void sock_rps_reset_rxhash(struct sock *sk)
} }
#define sk_wait_event(__sk, __timeo, __condition, __wait) \ #define sk_wait_event(__sk, __timeo, __condition, __wait) \
({ int __rc; \ ({ int __rc, __dis = __sk->sk_disconnects; \
__sk->sk_wait_pending++; \
release_sock(__sk); \ release_sock(__sk); \
__rc = __condition; \ __rc = __condition; \
if (!__rc) { \ if (!__rc) { \
@ -1200,8 +1199,7 @@ static inline void sock_rps_reset_rxhash(struct sock *sk)
} \ } \
sched_annotate_sleep(); \ sched_annotate_sleep(); \
lock_sock(__sk); \ lock_sock(__sk); \
__sk->sk_wait_pending--; \ __rc = __dis == __sk->sk_disconnects ? __condition : -EPIPE; \
__rc = __condition; \
__rc; \ __rc; \
}) })

View File

@ -117,7 +117,7 @@ EXPORT_SYMBOL(sk_stream_wait_close);
*/ */
int sk_stream_wait_memory(struct sock *sk, long *timeo_p) int sk_stream_wait_memory(struct sock *sk, long *timeo_p)
{ {
int err = 0; int ret, err = 0;
long vm_wait = 0; long vm_wait = 0;
long current_timeo = *timeo_p; long current_timeo = *timeo_p;
DEFINE_WAIT_FUNC(wait, woken_wake_function); DEFINE_WAIT_FUNC(wait, woken_wake_function);
@ -142,11 +142,13 @@ int sk_stream_wait_memory(struct sock *sk, long *timeo_p)
set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
sk->sk_write_pending++; sk->sk_write_pending++;
sk_wait_event(sk, &current_timeo, READ_ONCE(sk->sk_err) || ret = sk_wait_event(sk, &current_timeo, READ_ONCE(sk->sk_err) ||
(READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) || (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) ||
(sk_stream_memory_free(sk) && (sk_stream_memory_free(sk) && !vm_wait),
!vm_wait), &wait); &wait);
sk->sk_write_pending--; sk->sk_write_pending--;
if (ret < 0)
goto do_error;
if (vm_wait) { if (vm_wait) {
vm_wait -= current_timeo; vm_wait -= current_timeo;

View File

@ -597,7 +597,6 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias)
add_wait_queue(sk_sleep(sk), &wait); add_wait_queue(sk_sleep(sk), &wait);
sk->sk_write_pending += writebias; sk->sk_write_pending += writebias;
sk->sk_wait_pending++;
/* Basic assumption: if someone sets sk->sk_err, he _must_ /* Basic assumption: if someone sets sk->sk_err, he _must_
* change state of the socket from TCP_SYN_*. * change state of the socket from TCP_SYN_*.
@ -613,7 +612,6 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias)
} }
remove_wait_queue(sk_sleep(sk), &wait); remove_wait_queue(sk_sleep(sk), &wait);
sk->sk_write_pending -= writebias; sk->sk_write_pending -= writebias;
sk->sk_wait_pending--;
return timeo; return timeo;
} }
@ -642,6 +640,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
return -EINVAL; return -EINVAL;
if (uaddr->sa_family == AF_UNSPEC) { if (uaddr->sa_family == AF_UNSPEC) {
sk->sk_disconnects++;
err = sk->sk_prot->disconnect(sk, flags); err = sk->sk_prot->disconnect(sk, flags);
sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED; sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED;
goto out; goto out;
@ -696,6 +695,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
int writebias = (sk->sk_protocol == IPPROTO_TCP) && int writebias = (sk->sk_protocol == IPPROTO_TCP) &&
tcp_sk(sk)->fastopen_req && tcp_sk(sk)->fastopen_req &&
tcp_sk(sk)->fastopen_req->data ? 1 : 0; tcp_sk(sk)->fastopen_req->data ? 1 : 0;
int dis = sk->sk_disconnects;
/* Error code is set above */ /* Error code is set above */
if (!timeo || !inet_wait_for_connect(sk, timeo, writebias)) if (!timeo || !inet_wait_for_connect(sk, timeo, writebias))
@ -704,6 +704,11 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
err = sock_intr_errno(timeo); err = sock_intr_errno(timeo);
if (signal_pending(current)) if (signal_pending(current))
goto out; goto out;
if (dis != sk->sk_disconnects) {
err = -EPIPE;
goto out;
}
} }
/* Connection was closed by RST, timeout, ICMP error /* Connection was closed by RST, timeout, ICMP error
@ -725,6 +730,7 @@ out:
sock_error: sock_error:
err = sock_error(sk) ? : -ECONNABORTED; err = sock_error(sk) ? : -ECONNABORTED;
sock->state = SS_UNCONNECTED; sock->state = SS_UNCONNECTED;
sk->sk_disconnects++;
if (sk->sk_prot->disconnect(sk, flags)) if (sk->sk_prot->disconnect(sk, flags))
sock->state = SS_DISCONNECTING; sock->state = SS_DISCONNECTING;
goto out; goto out;

View File

@ -1145,7 +1145,6 @@ struct sock *inet_csk_clone_lock(const struct sock *sk,
if (newsk) { if (newsk) {
struct inet_connection_sock *newicsk = inet_csk(newsk); struct inet_connection_sock *newicsk = inet_csk(newsk);
newsk->sk_wait_pending = 0;
inet_sk_set_state(newsk, TCP_SYN_RECV); inet_sk_set_state(newsk, TCP_SYN_RECV);
newicsk->icsk_bind_hash = NULL; newicsk->icsk_bind_hash = NULL;
newicsk->icsk_bind2_hash = NULL; newicsk->icsk_bind2_hash = NULL;

View File

@ -831,7 +831,9 @@ ssize_t tcp_splice_read(struct socket *sock, loff_t *ppos,
*/ */
if (!skb_queue_empty(&sk->sk_receive_queue)) if (!skb_queue_empty(&sk->sk_receive_queue))
break; break;
sk_wait_data(sk, &timeo, NULL); ret = sk_wait_data(sk, &timeo, NULL);
if (ret < 0)
break;
if (signal_pending(current)) { if (signal_pending(current)) {
ret = sock_intr_errno(timeo); ret = sock_intr_errno(timeo);
break; break;
@ -2442,7 +2444,11 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
__sk_flush_backlog(sk); __sk_flush_backlog(sk);
} else { } else {
tcp_cleanup_rbuf(sk, copied); tcp_cleanup_rbuf(sk, copied);
sk_wait_data(sk, &timeo, last); err = sk_wait_data(sk, &timeo, last);
if (err < 0) {
err = copied ? : err;
goto out;
}
} }
if ((flags & MSG_PEEK) && if ((flags & MSG_PEEK) &&
@ -2966,12 +2972,6 @@ int tcp_disconnect(struct sock *sk, int flags)
int old_state = sk->sk_state; int old_state = sk->sk_state;
u32 seq; u32 seq;
/* Deny disconnect if other threads are blocked in sk_wait_event()
* or inet_wait_for_connect().
*/
if (sk->sk_wait_pending)
return -EBUSY;
if (old_state != TCP_CLOSE) if (old_state != TCP_CLOSE)
tcp_set_state(sk, TCP_CLOSE); tcp_set_state(sk, TCP_CLOSE);

View File

@ -307,6 +307,8 @@ msg_bytes_ready:
} }
data = tcp_msg_wait_data(sk, psock, timeo); data = tcp_msg_wait_data(sk, psock, timeo);
if (data < 0)
return data;
if (data && !sk_psock_queue_empty(psock)) if (data && !sk_psock_queue_empty(psock))
goto msg_bytes_ready; goto msg_bytes_ready;
copied = -EAGAIN; copied = -EAGAIN;
@ -351,6 +353,8 @@ msg_bytes_ready:
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
data = tcp_msg_wait_data(sk, psock, timeo); data = tcp_msg_wait_data(sk, psock, timeo);
if (data < 0)
return data;
if (data) { if (data) {
if (!sk_psock_queue_empty(psock)) if (!sk_psock_queue_empty(psock))
goto msg_bytes_ready; goto msg_bytes_ready;

View File

@ -3098,12 +3098,6 @@ static int mptcp_disconnect(struct sock *sk, int flags)
{ {
struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_sock *msk = mptcp_sk(sk);
/* Deny disconnect if other threads are blocked in sk_wait_event()
* or inet_wait_for_connect().
*/
if (sk->sk_wait_pending)
return -EBUSY;
/* We are on the fastopen error path. We can't call straight into the /* We are on the fastopen error path. We can't call straight into the
* subflows cleanup code due to lock nesting (we are already under * subflows cleanup code due to lock nesting (we are already under
* msk->firstsocket lock). * msk->firstsocket lock).
@ -3173,7 +3167,6 @@ struct sock *mptcp_sk_clone_init(const struct sock *sk,
inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk); inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk);
#endif #endif
nsk->sk_wait_pending = 0;
__mptcp_init_sock(nsk); __mptcp_init_sock(nsk);
msk = mptcp_sk(nsk); msk = mptcp_sk(nsk);

View File

@ -139,8 +139,8 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx)
int wait_on_pending_writer(struct sock *sk, long *timeo) int wait_on_pending_writer(struct sock *sk, long *timeo)
{ {
int rc = 0;
DEFINE_WAIT_FUNC(wait, woken_wake_function); DEFINE_WAIT_FUNC(wait, woken_wake_function);
int ret, rc = 0;
add_wait_queue(sk_sleep(sk), &wait); add_wait_queue(sk_sleep(sk), &wait);
while (1) { while (1) {
@ -154,9 +154,13 @@ int wait_on_pending_writer(struct sock *sk, long *timeo)
break; break;
} }
if (sk_wait_event(sk, timeo, ret = sk_wait_event(sk, timeo,
!READ_ONCE(sk->sk_write_pending), &wait)) !READ_ONCE(sk->sk_write_pending), &wait);
if (ret) {
if (ret < 0)
rc = ret;
break; break;
}
} }
remove_wait_queue(sk_sleep(sk), &wait); remove_wait_queue(sk_sleep(sk), &wait);
return rc; return rc;

View File

@ -1291,6 +1291,7 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
DEFINE_WAIT_FUNC(wait, woken_wake_function); DEFINE_WAIT_FUNC(wait, woken_wake_function);
int ret = 0;
long timeo; long timeo;
timeo = sock_rcvtimeo(sk, nonblock); timeo = sock_rcvtimeo(sk, nonblock);
@ -1302,6 +1303,9 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
if (sk->sk_err) if (sk->sk_err)
return sock_error(sk); return sock_error(sk);
if (ret < 0)
return ret;
if (!skb_queue_empty(&sk->sk_receive_queue)) { if (!skb_queue_empty(&sk->sk_receive_queue)) {
tls_strp_check_rcv(&ctx->strp); tls_strp_check_rcv(&ctx->strp);
if (tls_strp_msg_ready(ctx)) if (tls_strp_msg_ready(ctx))
@ -1320,10 +1324,10 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
released = true; released = true;
add_wait_queue(sk_sleep(sk), &wait); add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
sk_wait_event(sk, &timeo, ret = sk_wait_event(sk, &timeo,
tls_strp_msg_ready(ctx) || tls_strp_msg_ready(ctx) ||
!sk_psock_queue_empty(psock), !sk_psock_queue_empty(psock),
&wait); &wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait); remove_wait_queue(sk_sleep(sk), &wait);
@ -1852,6 +1856,7 @@ static int tls_rx_reader_acquire(struct sock *sk, struct tls_sw_context_rx *ctx,
bool nonblock) bool nonblock)
{ {
long timeo; long timeo;
int ret;
timeo = sock_rcvtimeo(sk, nonblock); timeo = sock_rcvtimeo(sk, nonblock);
@ -1861,14 +1866,16 @@ static int tls_rx_reader_acquire(struct sock *sk, struct tls_sw_context_rx *ctx,
ctx->reader_contended = 1; ctx->reader_contended = 1;
add_wait_queue(&ctx->wq, &wait); add_wait_queue(&ctx->wq, &wait);
sk_wait_event(sk, &timeo, ret = sk_wait_event(sk, &timeo,
!READ_ONCE(ctx->reader_present), &wait); !READ_ONCE(ctx->reader_present), &wait);
remove_wait_queue(&ctx->wq, &wait); remove_wait_queue(&ctx->wq, &wait);
if (timeo <= 0) if (timeo <= 0)
return -EAGAIN; return -EAGAIN;
if (signal_pending(current)) if (signal_pending(current))
return sock_intr_errno(timeo); return sock_intr_errno(timeo);
if (ret < 0)
return ret;
} }
WRITE_ONCE(ctx->reader_present, 1); WRITE_ONCE(ctx->reader_present, 1);