diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c index 5692daf57a4d..ae36155ff128 100644 --- a/net/mptcp/pm_netlink.c +++ b/net/mptcp/pm_netlink.c @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -1005,8 +1006,8 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk, bool is_ipv6 = sk->sk_family == AF_INET6; int addrlen = sizeof(struct sockaddr_in); struct sockaddr_storage addr; + struct sock *newsk, *ssk; struct socket *ssock; - struct sock *newsk; int backlog = 1024; int err; @@ -1042,18 +1043,23 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk, if (entry->addr.family == AF_INET6) addrlen = sizeof(struct sockaddr_in6); #endif - err = kernel_bind(ssock, (struct sockaddr *)&addr, addrlen); + ssk = mptcp_sk(newsk)->first; + if (ssk->sk_family == AF_INET) + err = inet_bind_sk(ssk, (struct sockaddr *)&addr, addrlen); +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (ssk->sk_family == AF_INET6) + err = inet6_bind_sk(ssk, (struct sockaddr *)&addr, addrlen); +#endif if (err) return err; inet_sk_state_store(newsk, TCP_LISTEN); - err = kernel_listen(ssock, backlog); - if (err) - return err; - - mptcp_event_pm_listener(ssock->sk, MPTCP_EVENT_LISTENER_CREATED); - - return 0; + lock_sock(ssk); + err = __inet_listen_sk(ssk, backlog); + if (!err) + mptcp_event_pm_listener(ssk, MPTCP_EVENT_LISTENER_CREATED); + release_sock(ssk); + return err; } int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info *skc)