Merge branch 'tls-rx-refactoring-part-2'

Jakub Kicinski says:

====================
tls: rx: random refactoring part 2

TLS Rx refactoring. Part 2 of 3. This one focusing on the main loop.
A couple of features to follow.
====================
This commit is contained in:
David S. Miller 2022-04-10 17:32:12 +01:00
commit 516a2f1f6f
2 changed files with 101 additions and 158 deletions

View file

@ -152,7 +152,6 @@ struct tls_sw_context_rx {
atomic_t decrypt_pending;
/* protect crypto_wait with decrypt_pending*/
spinlock_t decrypt_compl_lock;
bool async_notify;
};
struct tls_record_info {

View file

@ -44,6 +44,11 @@
#include <net/strparser.h>
#include <net/tls.h>
struct tls_decrypt_arg {
bool zc;
bool async;
};
noinline void tls_err_abort(struct sock *sk, int err)
{
WARN_ON_ONCE(err >= 0);
@ -168,7 +173,6 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
struct scatterlist *sg;
struct sk_buff *skb;
unsigned int pages;
int pending;
skb = (struct sk_buff *)req->data;
tls_ctx = tls_get_ctx(skb->sk);
@ -216,9 +220,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
kfree(aead_req);
spin_lock_bh(&ctx->decrypt_compl_lock);
pending = atomic_dec_return(&ctx->decrypt_pending);
if (!pending && ctx->async_notify)
if (!atomic_dec_return(&ctx->decrypt_pending))
complete(&ctx->async_wait.completion);
spin_unlock_bh(&ctx->decrypt_compl_lock);
}
@ -1345,15 +1347,14 @@ static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
return skb;
}
static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
static int tls_setup_from_iter(struct iov_iter *from,
int length, int *pages_used,
unsigned int *size_used,
struct scatterlist *to,
int to_max_pages)
{
int rc = 0, i = 0, num_elem = *pages_used, maxpages;
struct page *pages[MAX_SKB_FRAGS];
unsigned int size = *size_used;
unsigned int size = 0;
ssize_t copied, use;
size_t offset;
@ -1396,8 +1397,7 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
sg_mark_end(&to[num_elem - 1]);
out:
if (rc)
iov_iter_revert(from, size - *size_used);
*size_used = size;
iov_iter_revert(from, size);
*pages_used = num_elem;
return rc;
@ -1414,7 +1414,7 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
struct iov_iter *out_iov,
struct scatterlist *out_sg,
int *chunk, bool *zc, bool async)
struct tls_decrypt_arg *darg)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@ -1431,7 +1431,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
prot->tail_size;
int iv_offset = 0;
if (*zc && (out_iov || out_sg)) {
if (darg->zc && (out_iov || out_sg)) {
if (out_iov)
n_sgout = 1 +
iov_iter_npages_cap(out_iov, INT_MAX, data_len);
@ -1441,7 +1441,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
rxm->full_len - prot->prepend_size);
} else {
n_sgout = 0;
*zc = false;
darg->zc = false;
n_sgin = skb_cow_data(skb, 0, &unused);
}
@ -1523,9 +1523,8 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
sg_init_table(sgout, n_sgout);
sg_set_buf(&sgout[0], aad, prot->aad_size);
*chunk = 0;
err = tls_setup_from_iter(sk, out_iov, data_len,
&pages, chunk, &sgout[1],
err = tls_setup_from_iter(out_iov, data_len,
&pages, &sgout[1],
(n_sgout - 1));
if (err < 0)
goto fallback_to_reg_recv;
@ -1538,13 +1537,12 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
fallback_to_reg_recv:
sgout = sgin;
pages = 0;
*chunk = data_len;
*zc = false;
darg->zc = false;
}
/* Prepare and submit AEAD request */
err = tls_do_decryption(sk, skb, sgin, sgout, iv,
data_len, aead_req, async);
data_len, aead_req, darg->async);
if (err == -EINPROGRESS)
return err;
@ -1557,8 +1555,8 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
}
static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
struct iov_iter *dest, int *chunk, bool *zc,
bool async)
struct iov_iter *dest,
struct tls_decrypt_arg *darg)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
@ -1567,7 +1565,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
int pad, err;
if (tlm->decrypted) {
*zc = false;
darg->zc = false;
return 0;
}
@ -1577,12 +1575,12 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
return err;
if (err > 0) {
tlm->decrypted = 1;
*zc = false;
darg->zc = false;
goto decrypt_done;
}
}
err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
err = decrypt_internal(sk, skb, dest, NULL, darg);
if (err < 0) {
if (err == -EINPROGRESS)
tls_advance_record_sn(sk, prot, &tls_ctx->rx);
@ -1608,34 +1606,32 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
int decrypt_skb(struct sock *sk, struct sk_buff *skb,
struct scatterlist *sgout)
{
bool zc = true;
int chunk;
struct tls_decrypt_arg darg = { .zc = true, };
return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
return decrypt_internal(sk, skb, NULL, sgout, &darg);
}
static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
unsigned int len)
static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
u8 *control)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
int err;
if (skb) {
struct strp_msg *rxm = strp_msg(skb);
if (!*control) {
*control = tlm->control;
if (!*control)
return -EBADMSG;
if (len < rxm->full_len) {
rxm->offset += len;
rxm->full_len -= len;
return false;
err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
sizeof(*control), control);
if (*control != TLS_RECORD_TYPE_DATA) {
if (err || msg->msg_flags & MSG_CTRUNC)
return -EIO;
}
consume_skb(skb);
} else if (*control != tlm->control) {
return 0;
}
/* Finished with message */
ctx->recv_pkt = NULL;
__strp_unpause(&ctx->strp);
return true;
return 1;
}
/* This function traverses the rx_list in tls receive context to copy the
@ -1646,31 +1642,23 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
static int process_rx_list(struct tls_sw_context_rx *ctx,
struct msghdr *msg,
u8 *control,
bool *cmsg,
size_t skip,
size_t len,
bool zc,
bool is_peek)
{
struct sk_buff *skb = skb_peek(&ctx->rx_list);
u8 ctrl = *control;
u8 msgc = *cmsg;
struct tls_msg *tlm;
ssize_t copied = 0;
/* Set the record type in 'control' if caller didn't pass it */
if (!ctrl && skb) {
tlm = tls_msg(skb);
ctrl = tlm->control;
}
int err;
while (skip && skb) {
struct strp_msg *rxm = strp_msg(skb);
tlm = tls_msg(skb);
/* Cannot process a record of different type */
if (ctrl != tlm->control)
return 0;
err = tls_record_content_type(msg, tlm, control);
if (err <= 0)
return err;
if (skip < rxm->full_len)
break;
@ -1686,27 +1674,12 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
tlm = tls_msg(skb);
/* Cannot process a record of different type */
if (ctrl != tlm->control)
return 0;
/* Set record type if not already done. For a non-data record,
* do not proceed if record type could not be copied.
*/
if (!msgc) {
int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
sizeof(ctrl), &ctrl);
msgc = true;
if (ctrl != TLS_RECORD_TYPE_DATA) {
if (cerr || msg->msg_flags & MSG_CTRUNC)
return -EIO;
*cmsg = msgc;
}
}
err = tls_record_content_type(msg, tlm, control);
if (err <= 0)
return err;
if (!zc || (rxm->full_len - skip) > len) {
int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
err = skb_copy_datagram_msg(skb, rxm->offset + skip,
msg, chunk);
if (err < 0)
return err;
@ -1743,7 +1716,6 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
skb = next_skb;
}
*control = ctrl;
return copied;
}
@ -1758,19 +1730,19 @@ int tls_sw_recvmsg(struct sock *sk,
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct sk_psock *psock;
int num_async, pending;
unsigned char control = 0;
ssize_t decrypted = 0;
struct strp_msg *rxm;
struct tls_msg *tlm;
struct sk_buff *skb;
ssize_t copied = 0;
bool cmsg = false;
bool async = false;
int target, err = 0;
long timeo;
bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
bool is_peek = flags & MSG_PEEK;
bool bpf_strp_enabled;
bool zc_capable;
flags |= nonblock;
@ -1782,8 +1754,7 @@ int tls_sw_recvmsg(struct sock *sk,
bpf_strp_enabled = sk_psock_strp_enabled(psock);
/* Process pending decrypted records. It must be non-zero-copy */
err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
is_peek);
err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek);
if (err < 0) {
tls_err_abort(sk, err);
goto end;
@ -1797,15 +1768,12 @@ int tls_sw_recvmsg(struct sock *sk,
len = len - copied;
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
prot->version != TLS_1_3_VERSION;
decrypted = 0;
num_async = 0;
while (len && (decrypted + copied < target || ctx->recv_pkt)) {
bool retain_skb = false;
bool zc = false;
int to_decrypt;
int chunk = 0;
bool async_capable;
bool async = false;
struct tls_decrypt_arg darg = {};
int to_decrypt, chunk;
skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
if (!skb) {
@ -1827,29 +1795,24 @@ int tls_sw_recvmsg(struct sock *sk,
to_decrypt = rxm->full_len - prot->overhead_size;
if (to_decrypt <= len && !is_kvec && !is_peek &&
tlm->control == TLS_RECORD_TYPE_DATA &&
prot->version != TLS_1_3_VERSION &&
!bpf_strp_enabled)
zc = true;
if (zc_capable && to_decrypt <= len &&
tlm->control == TLS_RECORD_TYPE_DATA)
darg.zc = true;
/* Do not use async mode if record is non-data */
if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
async_capable = ctx->async_capable;
darg.async = ctx->async_capable;
else
async_capable = false;
darg.async = false;
err = decrypt_skb_update(sk, skb, &msg->msg_iter,
&chunk, &zc, async_capable);
err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg);
if (err < 0 && err != -EINPROGRESS) {
tls_err_abort(sk, -EBADMSG);
goto recv_end;
}
if (err == -EINPROGRESS) {
if (err == -EINPROGRESS)
async = true;
num_async++;
}
/* If the type of records being processed is not known yet,
* set it to record type just dequeued. If it is already known,
@ -1858,92 +1821,79 @@ int tls_sw_recvmsg(struct sock *sk,
* is known just after record is dequeued from stream parser.
* For tls1.3, we disable async.
*/
if (!control)
control = tlm->control;
else if (control != tlm->control)
err = tls_record_content_type(msg, tlm, &control);
if (err <= 0)
goto recv_end;
if (!cmsg) {
int cerr;
ctx->recv_pkt = NULL;
__strp_unpause(&ctx->strp);
skb_queue_tail(&ctx->rx_list, skb);
cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
sizeof(control), &control);
cmsg = true;
if (control != TLS_RECORD_TYPE_DATA) {
if (cerr || msg->msg_flags & MSG_CTRUNC) {
err = -EIO;
goto recv_end;
}
}
if (async) {
/* TLS 1.2-only, to_decrypt must be text length */
chunk = min_t(int, to_decrypt, len);
leave_on_list:
decrypted += chunk;
len -= chunk;
continue;
}
/* TLS 1.3 may have updated the length by more than overhead */
chunk = rxm->full_len;
if (async)
goto pick_next_record;
if (!darg.zc) {
bool partially_consumed = chunk > len;
if (!zc) {
if (bpf_strp_enabled) {
err = sk_psock_tls_strp_read(psock, skb);
if (err != __SK_PASS) {
rxm->offset = rxm->offset + rxm->full_len;
rxm->full_len = 0;
skb_unlink(skb, &ctx->rx_list);
if (err == __SK_DROP)
consume_skb(skb);
ctx->recv_pkt = NULL;
__strp_unpause(&ctx->strp);
continue;
}
}
if (rxm->full_len > len) {
retain_skb = true;
if (partially_consumed)
chunk = len;
} else {
chunk = rxm->full_len;
}
err = skb_copy_datagram_msg(skb, rxm->offset,
msg, chunk);
if (err < 0)
goto recv_end;
if (!is_peek) {
rxm->offset = rxm->offset + chunk;
rxm->full_len = rxm->full_len - chunk;
if (is_peek)
goto leave_on_list;
if (partially_consumed) {
rxm->offset += chunk;
rxm->full_len -= chunk;
goto leave_on_list;
}
}
pick_next_record:
if (chunk > len)
chunk = len;
decrypted += chunk;
len -= chunk;
/* For async or peek case, queue the current skb */
if (async || is_peek || retain_skb) {
skb_queue_tail(&ctx->rx_list, skb);
skb = NULL;
}
skb_unlink(skb, &ctx->rx_list);
consume_skb(skb);
if (tls_sw_advance_skb(sk, skb, chunk)) {
/* Return full control message to
* userspace before trying to parse
* another message type
*/
msg->msg_flags |= MSG_EOR;
if (control != TLS_RECORD_TYPE_DATA)
goto recv_end;
} else {
/* Return full control message to userspace before trying
* to parse another message type
*/
msg->msg_flags |= MSG_EOR;
if (control != TLS_RECORD_TYPE_DATA)
break;
}
}
recv_end:
if (num_async) {
if (async) {
int pending;
/* Wait for all previously submitted records to be decrypted */
spin_lock_bh(&ctx->decrypt_compl_lock);
ctx->async_notify = true;
reinit_completion(&ctx->async_wait.completion);
pending = atomic_read(&ctx->decrypt_pending);
spin_unlock_bh(&ctx->decrypt_compl_lock);
if (pending) {
@ -1955,21 +1905,14 @@ int tls_sw_recvmsg(struct sock *sk,
decrypted = 0;
goto end;
}
} else {
reinit_completion(&ctx->async_wait.completion);
}
/* There can be no concurrent accesses, since we have no
* pending decrypt operations
*/
WRITE_ONCE(ctx->async_notify, false);
/* Drain records from the rx_list & copy if required */
if (is_peek || is_kvec)
err = process_rx_list(ctx, msg, &control, &cmsg, copied,
err = process_rx_list(ctx, msg, &control, copied,
decrypted, false, is_peek);
else
err = process_rx_list(ctx, msg, &control, &cmsg, 0,
err = process_rx_list(ctx, msg, &control, 0,
decrypted, true, is_peek);
if (err < 0) {
tls_err_abort(sk, err);
@ -2003,7 +1946,6 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
int err = 0;
long timeo;
int chunk;
bool zc = false;
lock_sock(sk);
@ -2013,12 +1955,14 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
if (from_queue) {
skb = __skb_dequeue(&ctx->rx_list);
} else {
struct tls_decrypt_arg darg = {};
skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
&err);
if (!skb)
goto splice_read_end;
err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
err = decrypt_skb_update(sk, skb, NULL, &darg);
if (err < 0) {
tls_err_abort(sk, -EBADMSG);
goto splice_read_end;