diff --git a/net/mptcp/pm.c b/net/mptcp/pm.c index e63e14f4cf2a..09d6e736161d 100644 --- a/net/mptcp/pm.c +++ b/net/mptcp/pm.c @@ -20,6 +20,11 @@ int mptcp_pm_announce_addr(struct mptcp_sock *msk, pr_debug("msk=%p, local_id=%d", msk, addr->id); + if (add_addr) { + pr_warn("addr_signal error, add_addr=%d", add_addr); + return -EINVAL; + } + msk->pm.local = *addr; add_addr |= BIT(MPTCP_ADD_ADDR_SIGNAL); if (echo) @@ -34,10 +39,18 @@ int mptcp_pm_announce_addr(struct mptcp_sock *msk, int mptcp_pm_remove_addr(struct mptcp_sock *msk, u8 local_id) { + u8 rm_addr = READ_ONCE(msk->pm.add_addr_signal); + pr_debug("msk=%p, local_id=%d", msk, local_id); + if (rm_addr) { + pr_warn("addr_signal error, rm_addr=%d", rm_addr); + return -EINVAL; + } + msk->pm.rm_id = local_id; - WRITE_ONCE(msk->pm.rm_addr_signal, true); + rm_addr |= BIT(MPTCP_RM_ADDR_SIGNAL); + WRITE_ONCE(msk->pm.add_addr_signal, rm_addr); return 0; } @@ -231,7 +244,7 @@ bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining, goto out_unlock; *rm_id = msk->pm.rm_id; - WRITE_ONCE(msk->pm.rm_addr_signal, false); + WRITE_ONCE(msk->pm.add_addr_signal, 0); ret = true; out_unlock: @@ -253,7 +266,6 @@ void mptcp_pm_data_init(struct mptcp_sock *msk) msk->pm.rm_id = 0; WRITE_ONCE(msk->pm.work_pending, false); WRITE_ONCE(msk->pm.add_addr_signal, 0); - WRITE_ONCE(msk->pm.rm_addr_signal, false); WRITE_ONCE(msk->pm.accept_addr, false); WRITE_ONCE(msk->pm.accept_subflow, false); msk->pm.status = 0; diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h index e880fa802cdf..f002c12beb98 100644 --- a/net/mptcp/protocol.h +++ b/net/mptcp/protocol.h @@ -173,6 +173,7 @@ enum mptcp_add_addr_status { MPTCP_ADD_ADDR_ECHO, MPTCP_ADD_ADDR_IPV6, MPTCP_ADD_ADDR_PORT, + MPTCP_RM_ADDR_SIGNAL, }; struct mptcp_pm_data { @@ -183,7 +184,6 @@ struct mptcp_pm_data { spinlock_t lock; /*protects the whole PM data */ u8 add_addr_signal; - bool rm_addr_signal; bool server_side; bool work_pending; bool accept_addr; @@ -578,7 +578,7 @@ static inline bool mptcp_pm_should_add_signal_port(struct mptcp_sock *msk) static inline bool mptcp_pm_should_rm_signal(struct mptcp_sock *msk) { - return READ_ONCE(msk->pm.rm_addr_signal); + return READ_ONCE(msk->pm.add_addr_signal) & BIT(MPTCP_RM_ADDR_SIGNAL); } static inline unsigned int mptcp_add_addr_len(int family, bool echo, bool port)