diff --git a/net/unix/diag.c b/net/unix/diag.c index 105f522a89fe..616b55c5b890 100644 --- a/net/unix/diag.c +++ b/net/unix/diag.c @@ -114,14 +114,16 @@ static int sk_diag_show_rqlen(struct sock *sk, struct sk_buff *nlskb) return nla_put(nlskb, UNIX_DIAG_RQLEN, sizeof(rql), &rql); } -static int sk_diag_dump_uid(struct sock *sk, struct sk_buff *nlskb) +static int sk_diag_dump_uid(struct sock *sk, struct sk_buff *nlskb, + struct user_namespace *user_ns) { - uid_t uid = from_kuid_munged(sk_user_ns(nlskb->sk), sock_i_uid(sk)); + uid_t uid = from_kuid_munged(user_ns, sock_i_uid(sk)); return nla_put(nlskb, UNIX_DIAG_UID, sizeof(uid_t), &uid); } static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, struct unix_diag_req *req, - u32 portid, u32 seq, u32 flags, int sk_ino) + struct user_namespace *user_ns, + u32 portid, u32 seq, u32 flags, int sk_ino) { struct nlmsghdr *nlh; struct unix_diag_msg *rep; @@ -167,7 +169,7 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, struct unix_diag_r goto out_nlmsg_trim; if ((req->udiag_show & UDIAG_SHOW_UID) && - sk_diag_dump_uid(sk, skb)) + sk_diag_dump_uid(sk, skb, user_ns)) goto out_nlmsg_trim; nlmsg_end(skb, nlh); @@ -179,7 +181,8 @@ out_nlmsg_trim: } static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, struct unix_diag_req *req, - u32 portid, u32 seq, u32 flags) + struct user_namespace *user_ns, + u32 portid, u32 seq, u32 flags) { int sk_ino; @@ -190,7 +193,7 @@ static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, struct unix_diag_r if (!sk_ino) return 0; - return sk_diag_fill(sk, skb, req, portid, seq, flags, sk_ino); + return sk_diag_fill(sk, skb, req, user_ns, portid, seq, flags, sk_ino); } static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) @@ -214,7 +217,7 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) goto next; if (!(req->udiag_states & (1 << sk->sk_state))) goto next; - if (sk_diag_dump(sk, skb, req, + if (sk_diag_dump(sk, skb, req, sk_user_ns(skb->sk), NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, NLM_F_MULTI) < 0) { @@ -282,7 +285,8 @@ again: if (!rep) goto out; - err = sk_diag_fill(sk, rep, req, NETLINK_CB(in_skb).portid, + err = sk_diag_fill(sk, rep, req, sk_user_ns(NETLINK_CB(in_skb).sk), + NETLINK_CB(in_skb).portid, nlh->nlmsg_seq, 0, req->udiag_ino); if (err < 0) { nlmsg_free(rep);