diff --git a/include/net/af_unix.h b/include/net/af_unix.h index dc7469191195..414463803b7e 100644 --- a/include/net/af_unix.h +++ b/include/net/af_unix.h @@ -24,6 +24,7 @@ void unix_inflight(struct user_struct *user, struct file *fp); void unix_notinflight(struct user_struct *user, struct file *fp); void unix_add_edges(struct scm_fp_list *fpl, struct unix_sock *receiver); void unix_del_edges(struct scm_fp_list *fpl); +void unix_update_edges(struct unix_sock *receiver); int unix_prepare_fpl(struct scm_fp_list *fpl); void unix_destroy_fpl(struct scm_fp_list *fpl); void unix_gc(void); diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c index af74e7ebc35a..ae77e2dc0dae 100644 --- a/net/unix/af_unix.c +++ b/net/unix/af_unix.c @@ -1721,7 +1721,7 @@ static int unix_accept(struct socket *sock, struct socket *newsock, int flags, } tsk = skb->sk; - unix_sk(tsk)->listener = NULL; + unix_update_edges(unix_sk(tsk)); skb_free_datagram(sk, skb); wake_up_interruptible(&unix_sk(sk)->peer_wait); diff --git a/net/unix/garbage.c b/net/unix/garbage.c index 33aadaa35346..8d0912c1d01a 100644 --- a/net/unix/garbage.c +++ b/net/unix/garbage.c @@ -101,6 +101,17 @@ struct unix_sock *unix_get_socket(struct file *filp) return NULL; } +static struct unix_vertex *unix_edge_successor(struct unix_edge *edge) +{ + /* If an embryo socket has a fd, + * the listener indirectly holds the fd's refcnt. + */ + if (edge->successor->listener) + return unix_sk(edge->successor->listener)->vertex; + + return edge->successor->vertex; +} + static LIST_HEAD(unix_unvisited_vertices); enum unix_vertex_index { @@ -209,6 +220,13 @@ out: fpl->inflight = false; } +void unix_update_edges(struct unix_sock *receiver) +{ + spin_lock(&unix_gc_lock); + receiver->listener = NULL; + spin_unlock(&unix_gc_lock); +} + int unix_prepare_fpl(struct scm_fp_list *fpl) { struct unix_vertex *vertex; @@ -268,7 +286,7 @@ next_vertex: /* Explore neighbour vertices (receivers of the current vertex's fd). */ list_for_each_entry(edge, &vertex->edges, vertex_entry) { - struct unix_vertex *next_vertex = edge->successor->vertex; + struct unix_vertex *next_vertex = unix_edge_successor(edge); if (!next_vertex) continue;