diff --git a/drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.c b/drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.c
index bcdc7fc2f427..6482728794dd 100644
--- a/drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.c
+++ b/drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.c
@@ -361,9 +361,7 @@ static void chcr_ktls_dev_del(struct net_device *netdev,
 			      struct tls_context *tls_ctx,
 			      enum tls_offload_ctx_dir direction)
 {
-	struct chcr_ktls_ofld_ctx_tx *tx_ctx =
-		chcr_get_ktls_tx_context(tls_ctx);
-	struct chcr_ktls_info *tx_info = tx_ctx->chcr_info;
+	struct chcr_ktls_info *tx_info = chcr_get_ktls_tx_info(tls_ctx);
 	struct ch_ktls_port_stats_debug *port_stats;
 	struct chcr_ktls_uld_ctx *u_ctx;
 
@@ -396,7 +394,7 @@ static void chcr_ktls_dev_del(struct net_device *netdev,
 	port_stats = &tx_info->adap->ch_ktls_stats.ktls_port[tx_info->port_id];
 	atomic64_inc(&port_stats->ktls_tx_connection_close);
 	kvfree(tx_info);
-	tx_ctx->chcr_info = NULL;
+	chcr_set_ktls_tx_info(tls_ctx, NULL);
 	/* release module refcount */
 	module_put(THIS_MODULE);
 }
@@ -417,7 +415,6 @@ static int chcr_ktls_dev_add(struct net_device *netdev, struct sock *sk,
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct ch_ktls_port_stats_debug *port_stats;
-	struct chcr_ktls_ofld_ctx_tx *tx_ctx;
 	struct chcr_ktls_uld_ctx *u_ctx;
 	struct chcr_ktls_info *tx_info;
 	struct dst_entry *dst;
@@ -427,8 +424,6 @@ static int chcr_ktls_dev_add(struct net_device *netdev, struct sock *sk,
 	u8 daaddr[16];
 	int ret = -1;
 
-	tx_ctx = chcr_get_ktls_tx_context(tls_ctx);
-
 	pi = netdev_priv(netdev);
 	adap = pi->adapter;
 	port_stats = &adap->ch_ktls_stats.ktls_port[pi->port_id];
@@ -440,7 +435,7 @@ static int chcr_ktls_dev_add(struct net_device *netdev, struct sock *sk,
 		goto out;
 	}
 
-	if (tx_ctx->chcr_info)
+	if (chcr_get_ktls_tx_info(tls_ctx))
 		goto out;
 
 	if (u_ctx && u_ctx->detach)
@@ -566,7 +561,7 @@ static int chcr_ktls_dev_add(struct net_device *netdev, struct sock *sk,
 		goto free_tid;
 
 	atomic64_inc(&port_stats->ktls_tx_ctx);
-	tx_ctx->chcr_info = tx_info;
+	chcr_set_ktls_tx_info(tls_ctx, tx_info);
 
 	return 0;
 
@@ -647,7 +642,7 @@ static int chcr_ktls_cpl_act_open_rpl(struct adapter *adap,
 {
 	const struct cpl_act_open_rpl *p = (void *)input;
 	struct chcr_ktls_info *tx_info = NULL;
-	struct chcr_ktls_ofld_ctx_tx *tx_ctx;
+	struct tls_offload_context_tx *tx_ctx;
 	struct chcr_ktls_uld_ctx *u_ctx;
 	unsigned int atid, tid, status;
 	struct tls_context *tls_ctx;
@@ -686,7 +681,7 @@ static int chcr_ktls_cpl_act_open_rpl(struct adapter *adap,
 		cxgb4_insert_tid(t, tx_info, tx_info->tid, tx_info->ip_family);
 		/* Adding tid */
 		tls_ctx = tls_get_ctx(tx_info->sk);
-		tx_ctx = chcr_get_ktls_tx_context(tls_ctx);
+		tx_ctx = tls_offload_ctx_tx(tls_ctx);
 		u_ctx = adap->uld[CXGB4_ULD_KTLS].handle;
 		if (u_ctx) {
 			ret = xa_insert_bh(&u_ctx->tid_list, tid, tx_ctx,
@@ -1924,7 +1919,7 @@ static int chcr_ktls_xmit(struct sk_buff *skb, struct net_device *dev)
 {
 	u32 tls_end_offset, tcp_seq, skb_data_len, skb_offset;
 	struct ch_ktls_port_stats_debug *port_stats;
-	struct chcr_ktls_ofld_ctx_tx *tx_ctx;
+	struct tls_offload_context_tx *tx_ctx;
 	struct ch_ktls_stats_debug *stats;
 	struct tcphdr *th = tcp_hdr(skb);
 	int data_len, qidx, ret = 0, mss;
@@ -1944,6 +1939,7 @@ static int chcr_ktls_xmit(struct sk_buff *skb, struct net_device *dev)
 	mss = skb_is_gso(skb) ? skb_shinfo(skb)->gso_size : data_len;
 
 	tls_ctx = tls_get_ctx(skb->sk);
+	tx_ctx = tls_offload_ctx_tx(tls_ctx);
 	tls_netdev = rcu_dereference_bh(tls_ctx->netdev);
 	/* Don't quit on NULL: if tls_device_down is running in parallel,
 	 * netdev might become NULL, even if tls_is_skb_tx_device_offloaded was
@@ -1952,8 +1948,7 @@ static int chcr_ktls_xmit(struct sk_buff *skb, struct net_device *dev)
 	if (unlikely(tls_netdev && tls_netdev != dev))
 		goto out;
 
-	tx_ctx = chcr_get_ktls_tx_context(tls_ctx);
-	tx_info = tx_ctx->chcr_info;
+	tx_info = chcr_get_ktls_tx_info(tls_ctx);
 
 	if (unlikely(!tx_info))
 		goto out;
@@ -1979,19 +1974,19 @@ static int chcr_ktls_xmit(struct sk_buff *skb, struct net_device *dev)
 	 * we will send the complete record again.
 	 */
 
-	spin_lock_irqsave(&tx_ctx->base.lock, flags);
+	spin_lock_irqsave(&tx_ctx->lock, flags);
 
 	do {
 		cxgb4_reclaim_completed_tx(adap, &q->q, true);
 		/* fetch the tls record */
-		record = tls_get_record(&tx_ctx->base, tcp_seq,
+		record = tls_get_record(tx_ctx, tcp_seq,
 					&tx_info->record_no);
 		/* By the time packet reached to us, ACK is received, and record
 		 * won't be found in that case, handle it gracefully.
 		 */
 		if (unlikely(!record)) {
-			spin_unlock_irqrestore(&tx_ctx->base.lock, flags);
+			spin_unlock_irqrestore(&tx_ctx->lock, flags);
 			atomic64_inc(&port_stats->ktls_tx_drop_no_sync_data);
 			goto out;
 		}
@@ -2015,7 +2010,7 @@ static int chcr_ktls_xmit(struct sk_buff *skb, struct net_device *dev)
 					    tls_end_offset != record->len);
 			if (ret) {
-				spin_unlock_irqrestore(&tx_ctx->base.lock,
+				spin_unlock_irqrestore(&tx_ctx->lock,
 						       flags);
 				goto out;
 			}
@@ -2046,7 +2041,7 @@ static int chcr_ktls_xmit(struct sk_buff *skb, struct net_device *dev)
 			/* free the refcount taken earlier */
 			if (tls_end_offset < data_len)
 				dev_kfree_skb_any(skb);
-			spin_unlock_irqrestore(&tx_ctx->base.lock, flags);
+			spin_unlock_irqrestore(&tx_ctx->lock, flags);
 			goto out;
 		}
 
@@ -2082,7 +2077,7 @@ static int chcr_ktls_xmit(struct sk_buff *skb, struct net_device *dev)
 		/* if any failure, come out from the loop. */
 		if (ret) {
-			spin_unlock_irqrestore(&tx_ctx->base.lock, flags);
+			spin_unlock_irqrestore(&tx_ctx->lock, flags);
 			if (th->fin)
 				dev_kfree_skb_any(skb);
@@ -2097,7 +2092,7 @@ static int chcr_ktls_xmit(struct sk_buff *skb, struct net_device *dev)
 
 	} while (data_len > 0);
 
-	spin_unlock_irqrestore(&tx_ctx->base.lock, flags);
+	spin_unlock_irqrestore(&tx_ctx->lock, flags);
 	atomic64_inc(&port_stats->ktls_tx_encrypted_packets);
 	atomic64_add(skb_data_len, &port_stats->ktls_tx_encrypted_bytes);
@@ -2185,17 +2180,17 @@ static void clear_conn_resources(struct chcr_ktls_info *tx_info)
 static void ch_ktls_reset_all_conn(struct chcr_ktls_uld_ctx *u_ctx)
 {
 	struct ch_ktls_port_stats_debug *port_stats;
-	struct chcr_ktls_ofld_ctx_tx *tx_ctx;
+	struct tls_offload_context_tx *tx_ctx;
 	struct chcr_ktls_info *tx_info;
 	unsigned long index;
 
 	xa_for_each(&u_ctx->tid_list, index, tx_ctx) {
-		tx_info = tx_ctx->chcr_info;
+		tx_info = __chcr_get_ktls_tx_info(tx_ctx);
 		clear_conn_resources(tx_info);
 		port_stats = &tx_info->adap->ch_ktls_stats.ktls_port[tx_info->port_id];
 		atomic64_inc(&port_stats->ktls_tx_connection_close);
 		kvfree(tx_info);
-		tx_ctx->chcr_info = NULL;
+		memset(tx_ctx->driver_state, 0, TLS_DRIVER_STATE_SIZE_TX);
 		/* release module refcount */
 		module_put(THIS_MODULE);
 	}
diff --git a/drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.h b/drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.h
index 10572dc55365..dbbba92bf540 100644
--- a/drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.h
+++ b/drivers/net/ethernet/chelsio/inline_crypto/ch_ktls/chcr_ktls.h
@@ -67,8 +67,7 @@ struct chcr_ktls_info {
 	bool pending_close;
 };
 
-struct chcr_ktls_ofld_ctx_tx {
-	struct tls_offload_context_tx base;
+struct chcr_ktls_ctx_tx {
 	struct chcr_ktls_info *chcr_info;
 };
 
@@ -79,14 +78,33 @@ struct chcr_ktls_uld_ctx {
 	bool detach;
 };
 
-static inline struct chcr_ktls_ofld_ctx_tx *
-chcr_get_ktls_tx_context(struct tls_context *tls_ctx)
+static inline struct chcr_ktls_info *
+__chcr_get_ktls_tx_info(struct tls_offload_context_tx *octx)
 {
-	BUILD_BUG_ON(sizeof(struct chcr_ktls_ofld_ctx_tx) >
-		     TLS_OFFLOAD_CONTEXT_SIZE_TX);
-	return container_of(tls_offload_ctx_tx(tls_ctx),
-			    struct chcr_ktls_ofld_ctx_tx,
-			    base);
+	struct chcr_ktls_ctx_tx *priv_ctx;
+
+	BUILD_BUG_ON(sizeof(struct chcr_ktls_ctx_tx) > TLS_DRIVER_STATE_SIZE_TX);
+	priv_ctx = (struct chcr_ktls_ctx_tx *)octx->driver_state;
+	return priv_ctx->chcr_info;
+}
+
+static inline struct chcr_ktls_info *
+chcr_get_ktls_tx_info(struct tls_context *tls_ctx)
+{
+	struct chcr_ktls_ctx_tx *priv_ctx;
+
+	BUILD_BUG_ON(sizeof(struct chcr_ktls_ctx_tx) > TLS_DRIVER_STATE_SIZE_TX);
+	priv_ctx = (struct chcr_ktls_ctx_tx *)__tls_driver_ctx(tls_ctx, TLS_OFFLOAD_CTX_DIR_TX);
+	return priv_ctx->chcr_info;
+}
+
+static inline void
+chcr_set_ktls_tx_info(struct tls_context *tls_ctx, struct chcr_ktls_info *chcr_info)
+{
+	struct chcr_ktls_ctx_tx *priv_ctx;
+
+	priv_ctx = __tls_driver_ctx(tls_ctx, TLS_OFFLOAD_CTX_DIR_TX);
+	priv_ctx->chcr_info = chcr_info;
 }
 
 static inline int chcr_get_first_rx_qid(struct adapter *adap)
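Note on the chcr conversion above: instead of wrapping the core's struct tls_offload_context_tx in a driver struct and recovering it with container_of(), the driver now keeps its one pointer in the driver_state[] scratch area that the TLS core reserves per direction. For reference, the core accessor the new helpers build on is essentially the following (paraphrased from mainline include/net/tls.h; treat it as a sketch rather than the verbatim header):

static inline void *__tls_driver_ctx(struct tls_context *tls_ctx,
				     enum tls_offload_ctx_dir direction)
{
	/* per-direction scratch area reserved by the TLS core */
	if (direction == TLS_OFFLOAD_CTX_DIR_TX)
		return tls_offload_ctx_tx(tls_ctx)->driver_state;
	else
		return tls_offload_ctx_rx(tls_ctx)->driver_state;
}

Whatever a driver stores this way must fit in TLS_DRIVER_STATE_SIZE_TX (16 bytes); the BUILD_BUG_ON() checks in the new chcr helpers enforce that at compile time, which is why only a single struct chcr_ktls_info pointer lives there.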
diff --git a/include/net/tls.h b/include/net/tls.h
index a2b44578dcb7..962f0c501111 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -61,7 +61,8 @@ struct tls_rec;
 
 #define TLS_AAD_SPACE_SIZE		13
 
-#define MAX_IV_SIZE			16
+#define TLS_MAX_IV_SIZE			16
+#define TLS_MAX_SALT_SIZE		4
 #define TLS_TAG_SIZE			16
 #define TLS_MAX_REC_SEQ_SIZE		8
 #define TLS_MAX_AAD_SIZE		TLS_AAD_SPACE_SIZE
@@ -149,6 +150,7 @@ struct tls_record_info {
 	skb_frag_t frags[MAX_SKB_FRAGS];
 };
 
+#define TLS_DRIVER_STATE_SIZE_TX	16
 struct tls_offload_context_tx {
 	struct crypto_aead *aead_send;
 	spinlock_t lock;	/* protects records list */
@@ -162,17 +164,13 @@ struct tls_offload_context_tx {
 	void (*sk_destruct)(struct sock *sk);
 	struct work_struct destruct_work;
 	struct tls_context *ctx;
-	u8 driver_state[] __aligned(8);
 	/* The TLS layer reserves room for driver specific state
 	 * Currently the belief is that there is not enough
 	 * driver specific state to justify another layer of indirection
 	 */
-#define TLS_DRIVER_STATE_SIZE_TX	16
+	u8 driver_state[TLS_DRIVER_STATE_SIZE_TX] __aligned(8);
 };
 
-#define TLS_OFFLOAD_CONTEXT_SIZE_TX					\
-	(sizeof(struct tls_offload_context_tx) + TLS_DRIVER_STATE_SIZE_TX)
-
 enum tls_context_flags {
 	/* tls_device_down was called after the netdev went down, device state
 	 * was released, and kTLS works in software, even though rx_conf is
@@ -193,8 +191,8 @@ enum tls_context_flags {
 };
 
 struct cipher_context {
-	char *iv;
-	char *rec_seq;
+	char iv[TLS_MAX_IV_SIZE + TLS_MAX_SALT_SIZE];
+	char rec_seq[TLS_MAX_REC_SEQ_SIZE];
 };
 
 union tls_crypto_context {
@@ -302,6 +300,7 @@ struct tls_offload_resync_async {
 	u32 log[TLS_DEVICE_RESYNC_ASYNC_LOGMAX];
 };
 
+#define TLS_DRIVER_STATE_SIZE_RX	8
 struct tls_offload_context_rx {
 	/* sw must be the first member of tls_offload_context_rx */
 	struct tls_sw_context_rx sw;
@@ -325,17 +324,13 @@ struct tls_offload_context_rx {
 			struct tls_offload_resync_async *resync_async;
 		};
 	};
-	u8 driver_state[] __aligned(8);
 	/* The TLS layer reserves room for driver specific state
 	 * Currently the belief is that there is not enough
 	 * driver specific state to justify another layer of indirection
 	 */
-#define TLS_DRIVER_STATE_SIZE_RX	8
+	u8 driver_state[TLS_DRIVER_STATE_SIZE_RX] __aligned(8);
 };
 
-#define TLS_OFFLOAD_CONTEXT_SIZE_RX					\
-	(sizeof(struct tls_offload_context_rx) + TLS_DRIVER_STATE_SIZE_RX)
-
 struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
 				       u32 seq, u64 *p_record_sn);
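Note: with iv and rec_seq embedded directly in struct cipher_context, the iv array has to hold a salt followed by the per-record IV for the largest supported cipher, hence the TLS_MAX_IV_SIZE + TLS_MAX_SALT_SIZE sizing. A sketch of the resulting layout for AES-GCM-128 (4-byte salt, 8-byte IV, 8-byte record sequence per the uapi sizes; cctx stands for &ctx->tx or &ctx->rx as in the setup paths below):

	/* cctx->iv layout: [ salt | per-record IV ] in one flat buffer */
	memcpy(cctx->iv, salt, cipher_desc->salt);                 /* bytes 0..3  */
	memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv); /* bytes 4..11 */
	memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq);      /* 8 bytes     */

The CHECK_CIPHER_DESC static_asserts extended in net/tls/tls_main.c below keep every cipher descriptor within these bounds, so no runtime size check is needed and the per-direction kmalloc()/kmemdup() error paths disappear.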
diff --git a/net/tls/tls.h b/net/tls/tls.h
index 28a8c0e80e3c..478b2c0060aa 100644
--- a/net/tls/tls.h
+++ b/net/tls/tls.h
@@ -127,7 +127,7 @@ struct tls_rec {
 	struct sock *sk;
 
 	char aad_space[TLS_AAD_SPACE_SIZE];
-	u8 iv_data[MAX_IV_SIZE];
+	u8 iv_data[TLS_MAX_IV_SIZE];
 	struct aead_request aead_req;
 	u8 aead_req_ctx[];
 };
@@ -142,7 +142,11 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx);
 int wait_on_pending_writer(struct sock *sk, long *timeo);
 void tls_err_abort(struct sock *sk, int err);
 
-int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx);
+int init_prot_info(struct tls_prot_info *prot,
+		   const struct tls_crypto_info *crypto_info,
+		   const struct tls_cipher_desc *cipher_desc,
+		   int mode);
+int tls_set_sw_offload(struct sock *sk, int tx);
 void tls_update_rx_zc_capable(struct tls_context *tls_ctx);
 void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx);
 void tls_sw_strparser_done(struct tls_context *tls_ctx);
@@ -223,7 +227,7 @@ static inline bool tls_strp_msg_mixed_decrypted(struct tls_sw_context_rx *ctx)
 #ifdef CONFIG_TLS_DEVICE
 int tls_device_init(void);
 void tls_device_cleanup(void);
-int tls_set_device_offload(struct sock *sk, struct tls_context *ctx);
+int tls_set_device_offload(struct sock *sk);
 void tls_device_free_resources_tx(struct sock *sk);
 int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx);
 void tls_device_offload_cleanup_rx(struct sock *sk);
@@ -234,7 +238,7 @@ static inline int tls_device_init(void) { return 0; }
 static inline void tls_device_cleanup(void) {}
 
 static inline int
-tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
+tls_set_device_offload(struct sock *sk)
 {
 	return -EOPNOTSUPP;
 }
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 8c94c926606a..f01543557a60 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -56,11 +56,8 @@ static struct page *dummy_page;
 
 static void tls_device_free_ctx(struct tls_context *ctx)
 {
-	if (ctx->tx_conf == TLS_HW) {
+	if (ctx->tx_conf == TLS_HW)
 		kfree(tls_offload_ctx_tx(ctx));
-		kfree(ctx->tx.rec_seq);
-		kfree(ctx->tx.iv);
-	}
 
 	if (ctx->rx_conf == TLS_HW)
 		kfree(tls_offload_ctx_rx(ctx));
@@ -891,14 +888,8 @@ tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx)
 	struct strp_msg *rxm;
 	char *orig_buf, *buf;
 
-	switch (tls_ctx->crypto_recv.info.cipher_type) {
-	case TLS_CIPHER_AES_GCM_128:
-	case TLS_CIPHER_AES_GCM_256:
-		break;
-	default:
-		return -EINVAL;
-	}
 	cipher_desc = get_cipher_desc(tls_ctx->crypto_recv.info.cipher_type);
+	DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable);
 
 	rxm = strp_msg(tls_strp_msg(sw_ctx));
 	orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + cipher_desc->iv,
@@ -1042,22 +1033,45 @@ static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
 	}
 }
 
-int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
+static struct tls_offload_context_tx *alloc_offload_ctx_tx(struct tls_context *ctx)
+{
+	struct tls_offload_context_tx *offload_ctx;
+	__be64 rcd_sn;
+
+	offload_ctx = kzalloc(sizeof(*offload_ctx), GFP_KERNEL);
+	if (!offload_ctx)
+		return NULL;
+
+	INIT_WORK(&offload_ctx->destruct_work, tls_device_tx_del_task);
+	INIT_LIST_HEAD(&offload_ctx->records_list);
+	spin_lock_init(&offload_ctx->lock);
+	sg_init_table(offload_ctx->sg_tx_data,
+		      ARRAY_SIZE(offload_ctx->sg_tx_data));
+
+	/* start at rec_seq - 1 to account for the start marker record */
+	memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
+	offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
+
+	offload_ctx->ctx = ctx;
+
+	return offload_ctx;
+}
+
+int tls_set_device_offload(struct sock *sk)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_prot_info *prot = &tls_ctx->prot_info;
-	const struct tls_cipher_desc *cipher_desc;
 	struct tls_record_info *start_marker_record;
 	struct tls_offload_context_tx *offload_ctx;
+	const struct tls_cipher_desc *cipher_desc;
 	struct tls_crypto_info *crypto_info;
+	struct tls_prot_info *prot;
 	struct net_device *netdev;
-	char *iv, *rec_seq;
+	struct tls_context *ctx;
 	struct sk_buff *skb;
-	__be64 rcd_sn;
+	char *iv, *rec_seq;
 	int rc;
 
-	if (!ctx)
-		return -EINVAL;
+	ctx = tls_get_ctx(sk);
+	prot = &ctx->prot_info;
 
 	if (ctx->priv_ctx_tx)
 		return -EEXIST;
@@ -1085,38 +1099,23 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 		goto release_netdev;
 	}
 
+	rc = init_prot_info(prot, crypto_info, cipher_desc, TLS_HW);
+	if (rc)
+		goto release_netdev;
+
 	iv = crypto_info_iv(crypto_info, cipher_desc);
 	rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc);
 
-	prot->version = crypto_info->version;
-	prot->cipher_type = crypto_info->cipher_type;
-	prot->prepend_size = TLS_HEADER_SIZE + cipher_desc->iv;
-	prot->tag_size = cipher_desc->tag;
-	prot->overhead_size = prot->prepend_size + prot->tag_size;
-	prot->iv_size = cipher_desc->iv;
-	prot->salt_size = cipher_desc->salt;
-	ctx->tx.iv = kmalloc(cipher_desc->iv + cipher_desc->salt, GFP_KERNEL);
-	if (!ctx->tx.iv) {
-		rc = -ENOMEM;
-		goto release_netdev;
-	}
-
 	memcpy(ctx->tx.iv + cipher_desc->salt, iv, cipher_desc->iv);
-
-	prot->rec_seq_size = cipher_desc->rec_seq;
-	ctx->tx.rec_seq = kmemdup(rec_seq, cipher_desc->rec_seq, GFP_KERNEL);
-	if (!ctx->tx.rec_seq) {
-		rc = -ENOMEM;
-		goto free_iv;
-	}
+	memcpy(ctx->tx.rec_seq, rec_seq, cipher_desc->rec_seq);
 
 	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
 	if (!start_marker_record) {
 		rc = -ENOMEM;
-		goto free_rec_seq;
+		goto release_netdev;
 	}
 
-	offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
+	offload_ctx = alloc_offload_ctx_tx(ctx);
 	if (!offload_ctx) {
 		rc = -ENOMEM;
 		goto free_marker_record;
@@ -1126,22 +1125,10 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 	if (rc)
 		goto free_offload_ctx;
 
-	/* start at rec_seq - 1 to account for the start marker record */
-	memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
-	offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;
-
 	start_marker_record->end_seq = tcp_sk(sk)->write_seq;
 	start_marker_record->len = 0;
 	start_marker_record->num_frags = 0;
-
-	INIT_WORK(&offload_ctx->destruct_work, tls_device_tx_del_task);
-	offload_ctx->ctx = ctx;
-
-	INIT_LIST_HEAD(&offload_ctx->records_list);
 	list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
-	spin_lock_init(&offload_ctx->lock);
-	sg_init_table(offload_ctx->sg_tx_data,
-		      ARRAY_SIZE(offload_ctx->sg_tx_data));
 
 	clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
 	ctx->push_pending_record = tls_device_push_pending_record;
@@ -1198,10 +1185,6 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 	ctx->priv_ctx_tx = NULL;
 free_marker_record:
 	kfree(start_marker_record);
-free_rec_seq:
-	kfree(ctx->tx.rec_seq);
-free_iv:
-	kfree(ctx->tx.iv);
 release_netdev:
 	dev_put(netdev);
 	return rc;
@@ -1242,7 +1225,7 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
 		goto release_lock;
 	}
 
-	context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
+	context = kzalloc(sizeof(*context), GFP_KERNEL);
 	if (!context) {
 		rc = -ENOMEM;
 		goto release_lock;
@@ -1250,7 +1233,7 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
 	context->resync_nh_reset = 1;
 
 	ctx->priv_ctx_rx = context;
-	rc = tls_set_sw_offload(sk, ctx, 0);
+	rc = tls_set_sw_offload(sk, 0);
 	if (rc)
 		goto release_ctx;
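Note: factoring alloc_offload_ctx_tx() out of tls_set_device_offload() puts every field initialization next to the kzalloc(), and with iv/rec_seq now embedded in ctx->tx there is nothing left to kfree() on failure, so the unwind path collapses onto release_netdev/free_marker_record. A condensed sketch of the resulting setup sequence (labels as in the hunks above):

	rc = init_prot_info(prot, crypto_info, cipher_desc, TLS_HW);
	if (rc)
		goto release_netdev;

	/* iv/rec_seq live inside ctx->tx now: no allocations to unwind */
	memcpy(ctx->tx.iv + cipher_desc->salt, iv, cipher_desc->iv);
	memcpy(ctx->tx.rec_seq, rec_seq, cipher_desc->rec_seq);

	offload_ctx = alloc_offload_ctx_tx(ctx); /* also seeds unacked_record_sn */
	if (!offload_ctx) {
		rc = -ENOMEM;
		goto free_marker_record;
	}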
-		goto free_req;
-	}
 	cipher_desc = get_cipher_desc(tls_ctx->crypto_send.info.cipher_type);
+	DEBUG_NET_WARN_ON_ONCE(!cipher_desc || !cipher_desc->offloadable);
+
 	buf_len = cipher_desc->salt + cipher_desc->iv + TLS_AAD_SPACE_SIZE +
 		  sync_size + cipher_desc->tag;
 	buf = kmalloc(buf_len, GFP_ATOMIC);
@@ -356,6 +342,7 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
 		goto free_req;
 
 	iv = buf;
+	salt = crypto_info_salt(&tls_ctx->crypto_send.info, cipher_desc);
 	memcpy(iv, salt, cipher_desc->salt);
 	aad = buf + cipher_desc->salt + cipher_desc->iv;
 	dummy_buf = aad + TLS_AAD_SPACE_SIZE;
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 02f583ff9239..b125a08a618a 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -59,7 +59,8 @@ enum {
 };
 
 #define CHECK_CIPHER_DESC(cipher,ci)				\
-	static_assert(cipher ## _IV_SIZE <= MAX_IV_SIZE);		\
+	static_assert(cipher ## _IV_SIZE <= TLS_MAX_IV_SIZE);		\
+	static_assert(cipher ## _SALT_SIZE <= TLS_MAX_SALT_SIZE);	\
 	static_assert(cipher ## _REC_SEQ_SIZE <= TLS_MAX_REC_SEQ_SIZE);	\
 	static_assert(cipher ## _TAG_SIZE == TLS_TAG_SIZE);		\
 	static_assert(sizeof_field(struct ci, iv) == cipher ## _IV_SIZE);	\
@@ -344,8 +345,6 @@ static void tls_sk_proto_cleanup(struct sock *sk,
 
 	/* We need these for tls_sw_fallback handling of other packets */
 	if (ctx->tx_conf == TLS_SW) {
-		kfree(ctx->tx.rec_seq);
-		kfree(ctx->tx.iv);
 		tls_sw_release_resources_tx(sk);
 		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
 	} else if (ctx->tx_conf == TLS_HW) {
@@ -581,6 +580,31 @@ static int tls_getsockopt(struct sock *sk, int level, int optname,
 	return do_tls_getsockopt(sk, optname, optval, optlen);
 }
 
+static int validate_crypto_info(const struct tls_crypto_info *crypto_info,
+				const struct tls_crypto_info *alt_crypto_info)
+{
+	if (crypto_info->version != TLS_1_2_VERSION &&
+	    crypto_info->version != TLS_1_3_VERSION)
+		return -EINVAL;
+
+	switch (crypto_info->cipher_type) {
+	case TLS_CIPHER_ARIA_GCM_128:
+	case TLS_CIPHER_ARIA_GCM_256:
+		if (crypto_info->version != TLS_1_2_VERSION)
+			return -EINVAL;
+		break;
+	}
+
+	/* Ensure that TLS version and ciphers are same in both directions */
+	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
+		if (alt_crypto_info->version != crypto_info->version ||
+		    alt_crypto_info->cipher_type != crypto_info->cipher_type)
+			return -EINVAL;
+	}
+
+	return 0;
+}
+
 static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
 				  unsigned int optlen, int tx)
 {
@@ -612,21 +636,9 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
 		goto err_crypto_info;
 	}
 
-	/* check version */
-	if (crypto_info->version != TLS_1_2_VERSION &&
-	    crypto_info->version != TLS_1_3_VERSION) {
-		rc = -EINVAL;
+	rc = validate_crypto_info(crypto_info, alt_crypto_info);
+	if (rc)
 		goto err_crypto_info;
-	}
-
-	/* Ensure that TLS version and ciphers are same in both directions */
-	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
-		if (alt_crypto_info->version != crypto_info->version ||
-		    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
-			rc = -EINVAL;
-			goto err_crypto_info;
-		}
-	}
 
 	cipher_desc = get_cipher_desc(crypto_info->cipher_type);
 	if (!cipher_desc) {
@@ -634,16 +646,6 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
 		goto err_crypto_info;
 	}
 
-	switch (crypto_info->cipher_type) {
-	case TLS_CIPHER_ARIA_GCM_128:
-	case TLS_CIPHER_ARIA_GCM_256:
-		if (crypto_info->version != TLS_1_2_VERSION) {
-			rc = -EINVAL;
-			goto err_crypto_info;
-		}
-		break;
-	}
-
 	if (optlen != cipher_desc->crypto_info) {
 		rc = -EINVAL;
 		goto err_crypto_info;
@@ -658,13 +660,13 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
 	}
 
 	if (tx) {
-		rc = tls_set_device_offload(sk, ctx);
+		rc = tls_set_device_offload(sk);
 		conf = TLS_HW;
 		if (!rc) {
 			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
 			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
 		} else {
-			rc = tls_set_sw_offload(sk, ctx, 1);
+			rc = tls_set_sw_offload(sk, 1);
 			if (rc)
 				goto err_crypto_info;
 			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
@@ -678,7 +680,7 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
 			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
 			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
 		} else {
-			rc = tls_set_sw_offload(sk, ctx, 0);
+			rc = tls_set_sw_offload(sk, 0);
 			if (rc)
 				goto err_crypto_info;
 			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
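Note: validate_crypto_info() now fronts all of do_tls_setsockopt_conf()'s version and cipher checks, including the cross-direction consistency test. An illustrative rejection using the uapi types (a sketch: ARIA-GCM ciphers are defined for TLS 1.2 only, so the pairing below fails before any cipher descriptor lookup):

	struct tls12_crypto_info_aria_gcm_128 aria = {
		.info.version	  = TLS_1_3_VERSION,
		.info.cipher_type = TLS_CIPHER_ARIA_GCM_128,
	};

	/* returns -EINVAL: ARIA requires TLS_1_2_VERSION */
	rc = validate_crypto_info(&aria.info, &ctx->crypto_recv.info);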
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 270712b8d391..0f6da4ce3ed7 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -60,7 +60,7 @@ struct tls_decrypt_arg {
 
 struct tls_decrypt_ctx {
 	struct sock *sk;
-	u8 iv[MAX_IV_SIZE];
+	u8 iv[TLS_MAX_IV_SIZE];
 	u8 aad[TLS_MAX_AAD_SIZE];
 	u8 tail;
 	struct scatterlist sg[];
@@ -2319,7 +2319,7 @@ int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
 	struct tls_prot_info *prot = &tls_ctx->prot_info;
-	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
+	char header[TLS_HEADER_SIZE + TLS_MAX_IV_SIZE];
 	size_t cipher_overhead;
 	size_t data_len = 0;
 	int ret;
@@ -2467,9 +2467,6 @@ void tls_sw_release_resources_rx(struct sock *sk)
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
-	kfree(tls_ctx->rx.rec_seq);
-	kfree(tls_ctx->rx.iv);
-
 	if (ctx->aead_recv) {
 		__skb_queue_purge(&ctx->rx_list);
 		crypto_free_aead(ctx->aead_recv);
@@ -2581,84 +2578,54 @@ void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
 		tls_ctx->prot_info.version != TLS_1_3_VERSION;
 }
 
-int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
+static struct tls_sw_context_tx *init_ctx_tx(struct tls_context *ctx, struct sock *sk)
 {
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_prot_info *prot = &tls_ctx->prot_info;
-	struct tls_crypto_info *crypto_info;
-	struct tls_sw_context_tx *sw_ctx_tx = NULL;
-	struct tls_sw_context_rx *sw_ctx_rx = NULL;
-	struct cipher_context *cctx;
-	struct crypto_aead **aead;
-	struct crypto_tfm *tfm;
-	char *iv, *rec_seq, *key, *salt;
-	const struct tls_cipher_desc *cipher_desc;
-	u16 nonce_size;
-	int rc = 0;
+	struct tls_sw_context_tx *sw_ctx_tx;
 
-	if (!ctx) {
-		rc = -EINVAL;
-		goto out;
-	}
-
-	if (tx) {
-		if (!ctx->priv_ctx_tx) {
-			sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
-			if (!sw_ctx_tx) {
-				rc = -ENOMEM;
-				goto out;
-			}
-			ctx->priv_ctx_tx = sw_ctx_tx;
-		} else {
-			sw_ctx_tx =
-				(struct tls_sw_context_tx *)ctx->priv_ctx_tx;
-		}
+	if (!ctx->priv_ctx_tx) {
+		sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
+		if (!sw_ctx_tx)
+			return NULL;
 	} else {
-		if (!ctx->priv_ctx_rx) {
-			sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
-			if (!sw_ctx_rx) {
-				rc = -ENOMEM;
-				goto out;
-			}
-			ctx->priv_ctx_rx = sw_ctx_rx;
-		} else {
-			sw_ctx_rx =
-				(struct tls_sw_context_rx *)ctx->priv_ctx_rx;
-		}
+		sw_ctx_tx = ctx->priv_ctx_tx;
 	}
 
-	if (tx) {
-		crypto_init_wait(&sw_ctx_tx->async_wait);
-		spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
-		crypto_info = &ctx->crypto_send.info;
-		cctx = &ctx->tx;
-		aead = &sw_ctx_tx->aead_send;
-		INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
-		INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
-		sw_ctx_tx->tx_work.sk = sk;
+	crypto_init_wait(&sw_ctx_tx->async_wait);
+	spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
+	INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
+	INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
+	sw_ctx_tx->tx_work.sk = sk;
+
+	return sw_ctx_tx;
+}
+
+static struct tls_sw_context_rx *init_ctx_rx(struct tls_context *ctx)
+{
+	struct tls_sw_context_rx *sw_ctx_rx;
+
+	if (!ctx->priv_ctx_rx) {
+		sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
+		if (!sw_ctx_rx)
+			return NULL;
 	} else {
-		crypto_init_wait(&sw_ctx_rx->async_wait);
-		spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
-		init_waitqueue_head(&sw_ctx_rx->wq);
-		crypto_info = &ctx->crypto_recv.info;
-		cctx = &ctx->rx;
-		skb_queue_head_init(&sw_ctx_rx->rx_list);
-		skb_queue_head_init(&sw_ctx_rx->async_hold);
-		aead = &sw_ctx_rx->aead_recv;
+		sw_ctx_rx = ctx->priv_ctx_rx;
 	}
 
-	cipher_desc = get_cipher_desc(crypto_info->cipher_type);
-	if (!cipher_desc) {
-		rc = -EINVAL;
-		goto free_priv;
-	}
+	crypto_init_wait(&sw_ctx_rx->async_wait);
+	spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
	init_waitqueue_head(&sw_ctx_rx->wq);
+	skb_queue_head_init(&sw_ctx_rx->rx_list);
+	skb_queue_head_init(&sw_ctx_rx->async_hold);
 
-	nonce_size = cipher_desc->nonce;
+	return sw_ctx_rx;
+}
 
-	iv = crypto_info_iv(crypto_info, cipher_desc);
-	key = crypto_info_key(crypto_info, cipher_desc);
-	salt = crypto_info_salt(crypto_info, cipher_desc);
-	rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc);
+int init_prot_info(struct tls_prot_info *prot,
+		   const struct tls_crypto_info *crypto_info,
+		   const struct tls_cipher_desc *cipher_desc,
+		   int mode)
+{
+	u16 nonce_size = cipher_desc->nonce;
 
 	if (crypto_info->version == TLS_1_3_VERSION) {
 		nonce_size = 0;
@@ -2669,42 +2636,89 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 		prot->tail_size = 0;
 	}
 
-	/* Sanity-check the sizes for stack allocations. */
-	if (nonce_size > MAX_IV_SIZE || prot->aad_size > TLS_MAX_AAD_SIZE) {
-		rc = -EINVAL;
-		goto free_priv;
+	if (mode == TLS_HW) {
+		prot->aad_size = 0;
+		prot->tail_size = 0;
 	}
 
+	/* Sanity-check the sizes for stack allocations. */
+	if (nonce_size > TLS_MAX_IV_SIZE || prot->aad_size > TLS_MAX_AAD_SIZE)
+		return -EINVAL;
+
 	prot->version = crypto_info->version;
 	prot->cipher_type = crypto_info->cipher_type;
 	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
 	prot->tag_size = cipher_desc->tag;
-	prot->overhead_size = prot->prepend_size +
-			      prot->tag_size + prot->tail_size;
+	prot->overhead_size = prot->prepend_size + prot->tag_size + prot->tail_size;
 	prot->iv_size = cipher_desc->iv;
 	prot->salt_size = cipher_desc->salt;
-	cctx->iv = kmalloc(cipher_desc->iv + cipher_desc->salt, GFP_KERNEL);
-	if (!cctx->iv) {
-		rc = -ENOMEM;
+	prot->rec_seq_size = cipher_desc->rec_seq;
+
+	return 0;
+}
+
+int tls_set_sw_offload(struct sock *sk, int tx)
+{
+	struct tls_sw_context_tx *sw_ctx_tx = NULL;
+	struct tls_sw_context_rx *sw_ctx_rx = NULL;
+	const struct tls_cipher_desc *cipher_desc;
+	struct tls_crypto_info *crypto_info;
+	char *iv, *rec_seq, *key, *salt;
+	struct cipher_context *cctx;
+	struct tls_prot_info *prot;
+	struct crypto_aead **aead;
+	struct tls_context *ctx;
+	struct crypto_tfm *tfm;
+	int rc = 0;
+
+	ctx = tls_get_ctx(sk);
+	prot = &ctx->prot_info;
+
+	if (tx) {
+		ctx->priv_ctx_tx = init_ctx_tx(ctx, sk);
+		if (!ctx->priv_ctx_tx)
+			return -ENOMEM;
+
+		sw_ctx_tx = ctx->priv_ctx_tx;
+		crypto_info = &ctx->crypto_send.info;
+		cctx = &ctx->tx;
+		aead = &sw_ctx_tx->aead_send;
+	} else {
+		ctx->priv_ctx_rx = init_ctx_rx(ctx);
+		if (!ctx->priv_ctx_rx)
+			return -ENOMEM;
+
+		sw_ctx_rx = ctx->priv_ctx_rx;
+		crypto_info = &ctx->crypto_recv.info;
+		cctx = &ctx->rx;
+		aead = &sw_ctx_rx->aead_recv;
+	}
+
+	cipher_desc = get_cipher_desc(crypto_info->cipher_type);
+	if (!cipher_desc) {
+		rc = -EINVAL;
 		goto free_priv;
 	}
-	/* Note: 128 & 256 bit salt are the same size */
-	prot->rec_seq_size = cipher_desc->rec_seq;
+
+	rc = init_prot_info(prot, crypto_info, cipher_desc, TLS_SW);
+	if (rc)
+		goto free_priv;
+
+	iv = crypto_info_iv(crypto_info, cipher_desc);
+	key = crypto_info_key(crypto_info, cipher_desc);
+	salt = crypto_info_salt(crypto_info, cipher_desc);
+	rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc);
+
 	memcpy(cctx->iv, salt, cipher_desc->salt);
 	memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv);
-
-	cctx->rec_seq = kmemdup(rec_seq, cipher_desc->rec_seq, GFP_KERNEL);
-	if (!cctx->rec_seq) {
-		rc = -ENOMEM;
-		goto free_iv;
-	}
+	memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq);
 
 	if (!*aead) {
 		*aead = crypto_alloc_aead(cipher_desc->cipher_name, 0, 0);
 		if (IS_ERR(*aead)) {
 			rc = PTR_ERR(*aead);
 			*aead = NULL;
-			goto free_rec_seq;
+			goto free_priv;
 		}
 	}
 
@@ -2736,12 +2750,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 free_aead:
 	crypto_free_aead(*aead);
 	*aead = NULL;
-free_rec_seq:
-	kfree(cctx->rec_seq);
-	cctx->rec_seq = NULL;
-free_iv:
-	kfree(cctx->iv);
-	cctx->iv = NULL;
 free_priv:
 	if (tx) {
 		kfree(ctx->priv_ctx_tx);