diff --git a/include/net/tls.h b/include/net/tls.h index 154949c7b0c8..c36bf4c50027 100644 --- a/include/net/tls.h +++ b/include/net/tls.h @@ -124,6 +124,7 @@ struct tls_strparser { u32 mark : 8; u32 stopped : 1; u32 copy_mode : 1; + u32 mixed_decrypted : 1; u32 msg_ready : 1; struct strp_msg stm; diff --git a/net/tls/tls.h b/net/tls/tls.h index 804c3880d028..0672acab2773 100644 --- a/net/tls/tls.h +++ b/net/tls/tls.h @@ -167,6 +167,11 @@ static inline bool tls_strp_msg_ready(struct tls_sw_context_rx *ctx) return ctx->strp.msg_ready; } +static inline bool tls_strp_msg_mixed_decrypted(struct tls_sw_context_rx *ctx) +{ + return ctx->strp.mixed_decrypted; +} + #ifdef CONFIG_TLS_DEVICE int tls_device_init(void); void tls_device_cleanup(void); diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index 3b87c7b04ac8..bf69c9d6d06c 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -1007,20 +1007,14 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx) struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx); struct sk_buff *skb = tls_strp_msg(sw_ctx); struct strp_msg *rxm = strp_msg(skb); - int is_decrypted = skb->decrypted; - int is_encrypted = !is_decrypted; - struct sk_buff *skb_iter; - int left; + int is_decrypted, is_encrypted; - left = rxm->full_len + rxm->offset - skb_pagelen(skb); - /* Check if all the data is decrypted already */ - skb_iter = skb_shinfo(skb)->frag_list; - while (skb_iter && left > 0) { - is_decrypted &= skb_iter->decrypted; - is_encrypted &= !skb_iter->decrypted; - - left -= skb_iter->len; - skb_iter = skb_iter->next; + if (!tls_strp_msg_mixed_decrypted(sw_ctx)) { + is_decrypted = skb->decrypted; + is_encrypted = !is_decrypted; + } else { + is_decrypted = 0; + is_encrypted = 0; } trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len, diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c index 61fbf84baf9e..da95abbb7ea3 100644 --- a/net/tls/tls_strp.c +++ b/net/tls/tls_strp.c @@ -29,7 +29,8 @@ static void tls_strp_anchor_free(struct tls_strparser *strp) struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1); - shinfo->frag_list = NULL; + if (!strp->copy_mode) + shinfo->frag_list = NULL; consume_skb(strp->anchor); strp->anchor = NULL; } @@ -195,22 +196,22 @@ static void tls_strp_flush_anchor_copy(struct tls_strparser *strp) for (i = 0; i < shinfo->nr_frags; i++) __skb_frag_unref(&shinfo->frags[i], false); shinfo->nr_frags = 0; + if (strp->copy_mode) { + kfree_skb_list(shinfo->frag_list); + shinfo->frag_list = NULL; + } strp->copy_mode = 0; + strp->mixed_decrypted = 0; } -static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, - unsigned int offset, size_t in_len) +static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb, + struct sk_buff *in_skb, unsigned int offset, + size_t in_len) { - struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data; - struct sk_buff *skb; - skb_frag_t *frag; size_t len, chunk; + skb_frag_t *frag; int sz; - if (strp->msg_ready) - return 0; - - skb = strp->anchor; frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE]; len = in_len; @@ -228,10 +229,8 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, skb_frag_size_add(frag, chunk); sz = tls_rx_msg_size(strp, skb); - if (sz < 0) { - desc->error = sz; - return 0; - } + if (sz < 0) + return sz; /* We may have over-read, sz == 0 is guaranteed under-read */ if (unlikely(sz && sz < skb->len)) { @@ -271,15 +270,99 @@ static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, offset += chunk; } - if (strp->stm.full_len == skb->len) { +read_done: + return in_len - len; +} + +static int tls_strp_copyin_skb(struct tls_strparser *strp, struct sk_buff *skb, + struct sk_buff *in_skb, unsigned int offset, + size_t in_len) +{ + struct sk_buff *nskb, *first, *last; + struct skb_shared_info *shinfo; + size_t chunk; + int sz; + + if (strp->stm.full_len) + chunk = strp->stm.full_len - skb->len; + else + chunk = TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE; + chunk = min(chunk, in_len); + + nskb = tls_strp_skb_copy(strp, in_skb, offset, chunk); + if (!nskb) + return -ENOMEM; + + shinfo = skb_shinfo(skb); + if (!shinfo->frag_list) { + shinfo->frag_list = nskb; + nskb->prev = nskb; + } else { + first = shinfo->frag_list; + last = first->prev; + last->next = nskb; + first->prev = nskb; + } + + skb->len += chunk; + skb->data_len += chunk; + + if (!strp->stm.full_len) { + sz = tls_rx_msg_size(strp, skb); + if (sz < 0) + return sz; + + /* We may have over-read, sz == 0 is guaranteed under-read */ + if (unlikely(sz && sz < skb->len)) { + int over = skb->len - sz; + + WARN_ON_ONCE(over > chunk); + skb->len -= over; + skb->data_len -= over; + __pskb_trim(nskb, nskb->len - over); + + chunk -= over; + } + + strp->stm.full_len = sz; + } + + return chunk; +} + +static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, + unsigned int offset, size_t in_len) +{ + struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data; + struct sk_buff *skb; + int ret; + + if (strp->msg_ready) + return 0; + + skb = strp->anchor; + if (!skb->len) + skb_copy_decrypted(skb, in_skb); + else + strp->mixed_decrypted |= !!skb_cmp_decrypted(skb, in_skb); + + if (IS_ENABLED(CONFIG_TLS_DEVICE) && strp->mixed_decrypted) + ret = tls_strp_copyin_skb(strp, skb, in_skb, offset, in_len); + else + ret = tls_strp_copyin_frag(strp, skb, in_skb, offset, in_len); + if (ret < 0) { + desc->error = ret; + ret = 0; + } + + if (strp->stm.full_len && strp->stm.full_len == skb->len) { desc->count = 0; strp->msg_ready = 1; tls_rx_msg_ready(strp); } -read_done: - return in_len - len; + return ret; } static int tls_strp_read_copyin(struct tls_strparser *strp)