crypto: arm64/gcm-aes-ce - remove non-SIMD fallback path

Now that kernel mode SIMD is guaranteed to be available when executing
in task or softirq context, we no longer need scalar fallbacks to use
when the NEON is unavailable. So get rid of them.

Reviewed-by: Eric Biggers <ebiggers@google.com>
Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
This commit is contained in:
Ard Biesheuvel 2021-08-27 09:03:36 +02:00 committed by Herbert Xu
parent 4a7e1e5fc2
commit b9e699f912
1 changed files with 55 additions and 162 deletions

View File

@ -362,84 +362,36 @@ static int gcm_encrypt(struct aead_request *req)
err = skcipher_walk_aead_encrypt(&walk, req, false);
if (likely(crypto_simd_usable())) {
do {
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr;
int nbytes = walk.nbytes;
tag = (u8 *)&lengths;
if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) {
src = dst = memcpy(buf + sizeof(buf) - nbytes,
src, nbytes);
} else if (nbytes < walk.total) {
nbytes &= ~(AES_BLOCK_SIZE - 1);
tag = NULL;
}
kernel_neon_begin();
pmull_gcm_encrypt(nbytes, dst, src, ctx->ghash_key.h,
dg, iv, ctx->aes_key.key_enc, nrounds,
tag);
kernel_neon_end();
if (unlikely(!nbytes))
break;
if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE))
memcpy(walk.dst.virt.addr,
buf + sizeof(buf) - nbytes, nbytes);
err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
} while (walk.nbytes);
} else {
while (walk.nbytes >= AES_BLOCK_SIZE) {
int blocks = walk.nbytes / AES_BLOCK_SIZE;
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr;
int remaining = blocks;
do {
aes_encrypt(&ctx->aes_key, buf, iv);
crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE);
crypto_inc(iv, AES_BLOCK_SIZE);
dst += AES_BLOCK_SIZE;
src += AES_BLOCK_SIZE;
} while (--remaining > 0);
ghash_do_update(blocks, dg, walk.dst.virt.addr,
&ctx->ghash_key, NULL);
err = skcipher_walk_done(&walk,
walk.nbytes % AES_BLOCK_SIZE);
}
/* handle the tail */
if (walk.nbytes) {
aes_encrypt(&ctx->aes_key, buf, iv);
crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr,
buf, walk.nbytes);
memcpy(buf, walk.dst.virt.addr, walk.nbytes);
memset(buf + walk.nbytes, 0, sizeof(buf) - walk.nbytes);
}
do {
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr;
int nbytes = walk.nbytes;
tag = (u8 *)&lengths;
ghash_do_update(1, dg, tag, &ctx->ghash_key,
walk.nbytes ? buf : NULL);
if (walk.nbytes)
err = skcipher_walk_done(&walk, 0);
if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) {
src = dst = memcpy(buf + sizeof(buf) - nbytes,
src, nbytes);
} else if (nbytes < walk.total) {
nbytes &= ~(AES_BLOCK_SIZE - 1);
tag = NULL;
}
put_unaligned_be64(dg[1], tag);
put_unaligned_be64(dg[0], tag + 8);
put_unaligned_be32(1, iv + GCM_IV_SIZE);
aes_encrypt(&ctx->aes_key, iv, iv);
crypto_xor(tag, iv, AES_BLOCK_SIZE);
}
kernel_neon_begin();
pmull_gcm_encrypt(nbytes, dst, src, ctx->ghash_key.h,
dg, iv, ctx->aes_key.key_enc, nrounds,
tag);
kernel_neon_end();
if (unlikely(!nbytes))
break;
if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE))
memcpy(walk.dst.virt.addr,
buf + sizeof(buf) - nbytes, nbytes);
err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
} while (walk.nbytes);
if (err)
return err;
@ -464,6 +416,7 @@ static int gcm_decrypt(struct aead_request *req)
u64 dg[2] = {};
be128 lengths;
u8 *tag;
int ret;
int err;
lengths.a = cpu_to_be64(req->assoclen * 8);
@ -481,101 +434,41 @@ static int gcm_decrypt(struct aead_request *req)
err = skcipher_walk_aead_decrypt(&walk, req, false);
if (likely(crypto_simd_usable())) {
int ret;
do {
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr;
int nbytes = walk.nbytes;
tag = (u8 *)&lengths;
if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) {
src = dst = memcpy(buf + sizeof(buf) - nbytes,
src, nbytes);
} else if (nbytes < walk.total) {
nbytes &= ~(AES_BLOCK_SIZE - 1);
tag = NULL;
}
kernel_neon_begin();
ret = pmull_gcm_decrypt(nbytes, dst, src,
ctx->ghash_key.h,
dg, iv, ctx->aes_key.key_enc,
nrounds, tag, otag, authsize);
kernel_neon_end();
if (unlikely(!nbytes))
break;
if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE))
memcpy(walk.dst.virt.addr,
buf + sizeof(buf) - nbytes, nbytes);
err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
} while (walk.nbytes);
if (err)
return err;
if (ret)
return -EBADMSG;
} else {
while (walk.nbytes >= AES_BLOCK_SIZE) {
int blocks = walk.nbytes / AES_BLOCK_SIZE;
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr;
ghash_do_update(blocks, dg, walk.src.virt.addr,
&ctx->ghash_key, NULL);
do {
aes_encrypt(&ctx->aes_key, buf, iv);
crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE);
crypto_inc(iv, AES_BLOCK_SIZE);
dst += AES_BLOCK_SIZE;
src += AES_BLOCK_SIZE;
} while (--blocks > 0);
err = skcipher_walk_done(&walk,
walk.nbytes % AES_BLOCK_SIZE);
}
/* handle the tail */
if (walk.nbytes) {
memcpy(buf, walk.src.virt.addr, walk.nbytes);
memset(buf + walk.nbytes, 0, sizeof(buf) - walk.nbytes);
}
do {
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr;
int nbytes = walk.nbytes;
tag = (u8 *)&lengths;
ghash_do_update(1, dg, tag, &ctx->ghash_key,
walk.nbytes ? buf : NULL);
if (walk.nbytes) {
aes_encrypt(&ctx->aes_key, buf, iv);
crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr,
buf, walk.nbytes);
err = skcipher_walk_done(&walk, 0);
if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) {
src = dst = memcpy(buf + sizeof(buf) - nbytes,
src, nbytes);
} else if (nbytes < walk.total) {
nbytes &= ~(AES_BLOCK_SIZE - 1);
tag = NULL;
}
if (err)
return err;
kernel_neon_begin();
ret = pmull_gcm_decrypt(nbytes, dst, src, ctx->ghash_key.h,
dg, iv, ctx->aes_key.key_enc,
nrounds, tag, otag, authsize);
kernel_neon_end();
put_unaligned_be64(dg[1], tag);
put_unaligned_be64(dg[0], tag + 8);
put_unaligned_be32(1, iv + GCM_IV_SIZE);
aes_encrypt(&ctx->aes_key, iv, iv);
crypto_xor(tag, iv, AES_BLOCK_SIZE);
if (unlikely(!nbytes))
break;
if (crypto_memneq(tag, otag, authsize)) {
memzero_explicit(tag, AES_BLOCK_SIZE);
return -EBADMSG;
}
}
return 0;
if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE))
memcpy(walk.dst.virt.addr,
buf + sizeof(buf) - nbytes, nbytes);
err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
} while (walk.nbytes);
if (err)
return err;
return ret ? -EBADMSG : 0;
}
static struct aead_alg gcm_aes_alg = {