diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c index cca6ee716020..2ed29e40c6a9 100644 --- a/net/sunrpc/svcsock.c +++ b/net/sunrpc/svcsock.c @@ -43,7 +43,6 @@ #include #include #include -#include #include #include #include @@ -227,27 +226,30 @@ static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining) } static int -svc_tcp_sock_process_cmsg(struct svc_sock *svsk, struct msghdr *msg, +svc_tcp_sock_process_cmsg(struct socket *sock, struct msghdr *msg, struct cmsghdr *cmsg, int ret) { - if (cmsg->cmsg_level == SOL_TLS && - cmsg->cmsg_type == TLS_GET_RECORD_TYPE) { - u8 content_type = *((u8 *)CMSG_DATA(cmsg)); + u8 content_type = tls_get_record_type(sock->sk, cmsg); + u8 level, description; - switch (content_type) { - case TLS_RECORD_TYPE_DATA: - /* TLS sets EOR at the end of each application data - * record, even though there might be more frames - * waiting to be decrypted. - */ - msg->msg_flags &= ~MSG_EOR; - break; - case TLS_RECORD_TYPE_ALERT: - ret = -ENOTCONN; - break; - default: - ret = -EAGAIN; - } + switch (content_type) { + case 0: + break; + case TLS_RECORD_TYPE_DATA: + /* TLS sets EOR at the end of each application data + * record, even though there might be more frames + * waiting to be decrypted. + */ + msg->msg_flags &= ~MSG_EOR; + break; + case TLS_RECORD_TYPE_ALERT: + tls_alert_recv(sock->sk, msg, &level, &description); + ret = (level == TLS_ALERT_LEVEL_FATAL) ? + -ENOTCONN : -EAGAIN; + break; + default: + /* discard this record type */ + ret = -EAGAIN; } return ret; } @@ -259,13 +261,14 @@ svc_tcp_sock_recv_cmsg(struct svc_sock *svsk, struct msghdr *msg) struct cmsghdr cmsg; u8 buf[CMSG_SPACE(sizeof(u8))]; } u; + struct socket *sock = svsk->sk_sock; int ret; msg->msg_control = &u; msg->msg_controllen = sizeof(u); - ret = sock_recvmsg(svsk->sk_sock, msg, MSG_DONTWAIT); + ret = sock_recvmsg(sock, msg, MSG_DONTWAIT); if (unlikely(msg->msg_controllen != sizeof(u))) - ret = svc_tcp_sock_process_cmsg(svsk, msg, &u.cmsg, ret); + ret = svc_tcp_sock_process_cmsg(sock, msg, &u.cmsg, ret); return ret; } diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c index 5096aa62de5c..268a2cc61acd 100644 --- a/net/sunrpc/xprtsock.c +++ b/net/sunrpc/xprtsock.c @@ -47,7 +47,6 @@ #include #include #include -#include #include #include @@ -361,24 +360,27 @@ static int xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg, struct cmsghdr *cmsg, int ret) { - if (cmsg->cmsg_level == SOL_TLS && - cmsg->cmsg_type == TLS_GET_RECORD_TYPE) { - u8 content_type = *((u8 *)CMSG_DATA(cmsg)); + u8 content_type = tls_get_record_type(sock->sk, cmsg); + u8 level, description; - switch (content_type) { - case TLS_RECORD_TYPE_DATA: - /* TLS sets EOR at the end of each application data - * record, even though there might be more frames - * waiting to be decrypted. - */ - msg->msg_flags &= ~MSG_EOR; - break; - case TLS_RECORD_TYPE_ALERT: - ret = -ENOTCONN; - break; - default: - ret = -EAGAIN; - } + switch (content_type) { + case 0: + break; + case TLS_RECORD_TYPE_DATA: + /* TLS sets EOR at the end of each application data + * record, even though there might be more frames + * waiting to be decrypted. + */ + msg->msg_flags &= ~MSG_EOR; + break; + case TLS_RECORD_TYPE_ALERT: + tls_alert_recv(sock->sk, msg, &level, &description); + ret = (level == TLS_ALERT_LEVEL_FATAL) ? + -EACCES : -EAGAIN; + break; + default: + /* discard this record type */ + ret = -EAGAIN; } return ret; } @@ -778,6 +780,8 @@ static void xs_stream_data_receive(struct sock_xprt *transport) } if (ret == -ESHUTDOWN) kernel_sock_shutdown(transport->sock, SHUT_RDWR); + else if (ret == -EACCES) + xprt_wake_pending_tasks(&transport->xprt, -EACCES); else xs_poll_check_readable(transport); out: