crypto: arm/aes-neonbs-ctr - deal with non-multiples of AES block size

Instead of falling back to C code to deal with the trailing part of the
input when its length is not a multiple of the AES block size, handle
this in the asm code. This permits us to use overlapping loads and stores
for performance, and to implement the 16-byte wide XOR using a single
NEON instruction.

Since NEON loads and stores have a natural width of 16 bytes, we need to
handle inputs of less than 16 bytes in a special way, but this rarely
occurs in practice so it does not impact performance. All other input
sizes can be consumed directly by the NEON asm code, although it should
be noted that the core AES transform can still only process 128 bytes (8
AES blocks) at a time.

Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
This commit is contained in:
Ard Biesheuvel 2022-01-27 12:35:43 +01:00 committed by Herbert Xu
parent 8daa399ede
commit c8bf850e99
2 changed files with 82 additions and 68 deletions

View File

@ -758,29 +758,24 @@ ENTRY(aesbs_cbc_decrypt)
ENDPROC(aesbs_cbc_decrypt)
/*
 * next_ctr q - copy the current counter into the given NEON q register,
 * then advance the counter by one.
 *
 * The counter lives in r7-r10. r10 is incremented first and the carry is
 * propagated through r9, r8 and finally r7, so r10 is the least
 * significant word and r7 the most significant. The closing vrev32.8
 * reverses the byte order inside each 32-bit word of the q register.
 *
 * Clobbers r7-r10 (they hold the updated counter) and the condition flags.
 *
 * NOTE(review): this listing looks like a diff rendered without +/-
 * markers, so the per-lane "vmov.32" moves and the two-register "vmov"
 * moves are both present. They place the same words in the same halves of
 * the q register; confirm which form the final file keeps.
 */
.macro next_ctr, q
vmov.32 \q\()h[1], r10
vmov \q\()h, r9, r10 // high half: r9 in the low word, r10 in the high word
adds r10, r10, #1 // bump the least significant word, carry out in C
vmov.32 \q\()h[0], r9
adcs r9, r9, #0 // ripple the carry upwards
vmov.32 \q\()l[1], r8
vmov \q\()l, r7, r8 // low half: r7 in the low word, r8 in the high word
adcs r8, r8, #0
vmov.32 \q\()l[0], r7
adc r7, r7, #0 // most significant word; carry out is dropped
vrev32.8 \q, \q // byte-swap every 32-bit word
.endm
/*
* aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
* int rounds, int blocks, u8 ctr[], u8 final[])
* int rounds, int bytes, u8 ctr[])
*/
/*
 * Function entry. ip snapshots sp before the push so that the
 * stack-passed arguments (argument 4 onwards) can be loaded through it.
 * With the byte-count signature documented above, r5 receives the byte
 * count and r6 the pointer to the counter block.
 *
 * NOTE(review): this listing looks like a diff rendered without +/-
 * markers. "ldm ip, {r5-r7}" together with the teq/addne pair appears to
 * be the removed three-argument load (blocks, ctr, final), and
 * "ldm ip, {r5, r6}" its replacement - confirm against the real file.
 */
ENTRY(aesbs_ctr_encrypt)
mov ip, sp
push {r4-r10, lr}
ldm ip, {r5-r7} // load args 4-6
teq r7, #0
addne r5, r5, #1 // one extra block if final != 0
ldm ip, {r5, r6} // load args 4-5
vld1.8 {q0}, [r6] // load counter
vrev32.8 q1, q0 // q1 = counter with each 32-bit word byte-swapped
vmov r9, r10, d3 // r9/r10 = the two words in the high half of q1
@ -792,20 +787,19 @@ ENTRY(aesbs_ctr_encrypt)
adc r7, r7, #0
/*
 * 99: top of the main loop (re-entered through "bgt 99b"). q0 is copied
 * into q1-q7, after which each next_ctr that is executed refills its
 * register from the running counter in r7-r10. The computed goto jumps
 * into the run of next_ctr expansions so that only as many as are needed
 * get executed.
 *
 * Two variants of the address computation appear below:
 *   - block-count form: lr = (r5 - 1) & 7,
 *     ip = 0f - (lr << 5) - (lr << 2) = 0f - 36 * lr
 *   - byte-count form:  lr = (r5 - 1) & 112, a multiple of 16,
 *     ip = 0f - (lr << 1) + (lr >> 2) = 0f - 28 * (lr / 16)
 *
 * NOTE(review): this listing looks like a diff rendered without +/-
 * markers, so the two variants are presumably the removed and the added
 * code. Presumably 36 and 28 bytes are the sizes of one next_ctr
 * expansion before and after the change, and label 0: follows the last
 * expansion; neither is visible in this excerpt - confirm.
 */
99: vmov q1, q0
vmov q2, q0
vmov q3, q0
vmov q4, q0
vmov q5, q0
vmov q6, q0
vmov q7, q0
adr ip, 0f
sub lr, r5, #1
and lr, lr, #7
cmp r5, #8
sub ip, ip, lr, lsl #5
sub ip, ip, lr, lsl #2
movlt pc, ip // computed goto if blocks < 8
vmov q2, q0
adr ip, 0f
vmov q3, q0
and lr, lr, #112 // lr = 16 * (counter blocks needed beyond q0)
vmov q4, q0
cmp r5, #112
vmov q5, q0
sub ip, ip, lr, lsl #1
vmov q6, q0
add ip, ip, lr, lsr #2
vmov q7, q0
movle pc, ip // computed goto if bytes <= 112
next_ctr q1
next_ctr q2
@ -820,12 +814,14 @@ ENTRY(aesbs_ctr_encrypt)
bl aesbs_encrypt8
/*
 * Set up the computed goto into the run of input loads that ends at
 * label 1:, and work out how the final block has to be handled.
 *
 * Byte-count form (r5 = bytes left):
 *   lr = (r5 - 1) & ~15   byte offset of the last block to process
 *   C  = (r5 >= 128)      from "cmp r5, #128"; none of the following
 *                         instructions up to "movcc" changes C
 *   r4 = r5 & 15          Z set when r5 is a multiple of 16
 *   teqcs forces Z when C is set, so Z ends up clear only when
 *   r5 < 128 and r5 is not a multiple of 16, i.e. a partial final block
 *   r4 = 16 - (r5 & 15)   distance by which the pointers are later moved
 *                         back for the overlapping load and store
 *   ip = 1f - lr / 4      i.e. 4 bytes back from 1: for every 16-byte
 *                         block that precedes the last one
 *
 * NOTE(review): this listing looks like a diff rendered without +/-
 * markers; the lines working on "r5 & 7", "cmp r5, #8" and the load of
 * 'final' from the stack appear to be the removed block-count code.
 */
adr ip, 1f
and lr, r5, #7
cmp r5, #8
movgt r4, #0
ldrle r4, [sp, #40] // load final in the last round
sub ip, ip, lr, lsl #2
movlt pc, ip // computed goto if blocks < 8
sub lr, r5, #1
cmp r5, #128
bic lr, lr, #15
ands r4, r5, #15 // preserves C flag
teqcs r5, r5 // set Z flag if not last iteration
sub ip, ip, lr, lsr #2
rsb r4, r4, #16
movcc pc, ip // computed goto if bytes < 128
vld1.8 {q8}, [r1]!
vld1.8 {q9}, [r1]!
@ -834,46 +830,70 @@ ENTRY(aesbs_ctr_encrypt)
/*
 * Remainder of the input loads, the XOR with the keystream, the output
 * stores, the loop tail and the handling of a partial final block.
 *
 * NOTE(review): this listing looks like a diff rendered without +/-
 * markers, which is why numeric labels 1:, 2:, 3: and 4: each appear more
 * than once and why every veor/vst1 shows up both interleaved and grouped.
 * The comments below describe the byte-count code path; the lines that
 * test 'final', store through r4 at label 5: or count r5 in units of 8
 * appear to belong to the removed block-count version - confirm.
 */
vld1.8 {q12}, [r1]!
vld1.8 {q13}, [r1]!
vld1.8 {q14}, [r1]!
teq r4, #0 // skip last block if 'final'
1: bne 2f
1: subne r1, r1, r4 // partial final block: move r1 back by 16 - (bytes & 15)
vld1.8 {q15}, [r1]! // so that this 16-byte load ends at the last input byte
2: adr ip, 3f
cmp r5, #8
sub ip, ip, lr, lsl #3
movlt pc, ip // computed goto if blocks < 8
add ip, ip, #2f - 1b // same offset, now relative to 2: in the store sequence
veor q0, q0, q8
vst1.8 {q0}, [r0]!
veor q1, q1, q9
vst1.8 {q1}, [r0]!
veor q4, q4, q10
vst1.8 {q4}, [r0]!
veor q6, q6, q11
vst1.8 {q6}, [r0]!
veor q3, q3, q12
vst1.8 {q3}, [r0]!
veor q7, q7, q13
vst1.8 {q7}, [r0]!
veor q2, q2, q14
bne 3f // Z clear: partial final block, q5 is still raw keystream
veor q5, q5, q15
movcc pc, ip // computed goto if bytes < 128
vst1.8 {q0}, [r0]!
vst1.8 {q1}, [r0]!
vst1.8 {q4}, [r0]!
vst1.8 {q6}, [r0]!
vst1.8 {q3}, [r0]!
vst1.8 {q7}, [r0]!
vst1.8 {q2}, [r0]!
teq r4, #0 // skip last block if 'final'
W(bne) 5f
3: veor q5, q5, q15
2: subne r0, r0, r4 // partial final block: overlap the store with the previous one
vst1.8 {q5}, [r0]!
4: next_ctr q0
next_ctr q0 // q0 = counter block for the next iteration
subs r5, r5, #8
subs r5, r5, #128
bgt 99b
vst1.8 {q0}, [r6] // write the counter back through r6
pop {r4-r10, pc}
5: vst1.8 {q5}, [r4]
b 4b
/*
 * Partial final block. r4 = 16 - (bytes & 15), so lr below ends up as
 * .Lpermute_table + (bytes & 15) and q8/q9 receive two 16-byte index
 * vectors taken from the table at that offset.
 *
 * - vtbl with q8 moves the first (bytes & 15) keystream bytes of q5 to the
 *   top of the vector and zeroes the rest (VTBL yields 0 for an
 *   out-of-range index such as 0xff), lining the keystream up with the
 *   tail bytes in the overlapping load held in q15.
 * - When at least 16 bytes were processed in this call, vtbx with q9
 *   fills the remaining low bytes of q5 from q2, the block XORed just
 *   before this one (VTBX leaves a lane untouched for an out-of-range
 *   index). The code then jumps back into the store sequence, where the
 *   overlapping store rewrites those bytes with the same values.
 */
3: adr lr, .Lpermute_table + 16
cmp r5, #16 // Z flag remains cleared
sub lr, lr, r4
vld1.8 {q8-q9}, [lr]
vtbl.8 d16, {q5}, d16
vtbl.8 d17, {q5}, d17
veor q5, q8, q15
bcc 4f // have to reload prev if R5 < 16
vtbx.8 d10, {q2}, d18
vtbx.8 d11, {q2}, d19
mov pc, ip // branch back to VST sequence
/*
 * Fewer than 16 bytes in total: there is no previous block in registers,
 * so the 16 bytes ending at the end of the output are reloaded and only
 * the tail is replaced. The arithmetic shift turns every index byte of
 * q9 into 0x00 (valid index) or 0xff (0xff entry), and VBIF copies the
 * bits of q8 into q5 wherever that mask is 0. This path returns without
 * writing the counter back through r6.
 */
4: sub r0, r0, r4
vshr.s8 q9, q9, #7 // create mask for VBIF
vld1.8 {q8}, [r0] // reload
vbif q5, q8, q9
vst1.8 {q5}, [r0]
pop {r4-r10, pc}
ENDPROC(aesbs_ctr_encrypt)
/*
 * Index table for the partial-block handling in aesbs_ctr_encrypt:
 * 16 bytes of 0xff, the identity permutation 0x00..0x0f, and another
 * 16 bytes of 0xff (48 bytes in total).
 *
 * The code loads 32 bytes starting at .Lpermute_table + 16 - r4 and feeds
 * the two 16-byte halves to vtbl/vtbx as index vectors. 0xff is an
 * out-of-range index, so those lanes are zeroed by vtbl and left
 * unchanged by vtbx; sliding the load address therefore shifts a vector
 * by a variable number of bytes. The same data, arithmetic-shifted right
 * by 7, doubles as a byte mask (0x00 for index bytes, 0xff for 0xff).
 */
.align 6
.Lpermute_table:
.byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
.byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
.byte 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07
.byte 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
.byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
.byte 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
.macro next_tweak, out, in, const, tmp
vshr.s64 \tmp, \in, #63
vand \tmp, \tmp, \const
@ -888,6 +908,7 @@ ENDPROC(aesbs_ctr_encrypt)
* aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
* int blocks, u8 iv[], int reorder_last_tweak)
*/
.align 6
__xts_prepare8:
vld1.8 {q14}, [r7] // load iv
vmov.i32 d30, #0x87 // compose tweak mask vector

View File

@ -37,7 +37,7 @@ asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
int rounds, int blocks, u8 iv[]);
/*
 * NOTE(review): the caller in ctr_encrypt() passes a byte count ("bytes")
 * as the fifth argument, and the comment above the asm entry point names
 * that argument "int bytes"; the "blocks" parameter name kept in this
 * prototype looks stale. Confirm and rename if so.
 */
asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
int rounds, int blocks, u8 ctr[], u8 final[]);
int rounds, int blocks, u8 ctr[]);
asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
int rounds, int blocks, u8 iv[], int);
@ -243,32 +243,25 @@ static int ctr_encrypt(struct skcipher_request *req)
err = skcipher_walk_virt(&walk, req, false);
while (walk.nbytes > 0) {
unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr;
int bytes = walk.nbytes;
if (walk.nbytes < walk.total) {
blocks = round_down(blocks,
walk.stride / AES_BLOCK_SIZE);
final = NULL;
}
if (unlikely(bytes < AES_BLOCK_SIZE))
src = dst = memcpy(buf + sizeof(buf) - bytes,
src, bytes);
else if (walk.nbytes < walk.total)
bytes &= ~(8 * AES_BLOCK_SIZE - 1);
kernel_neon_begin();
aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
ctx->rk, ctx->rounds, blocks, walk.iv, final);
aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
kernel_neon_end();
if (final) {
u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;
if (unlikely(bytes < AES_BLOCK_SIZE))
memcpy(walk.dst.virt.addr,
buf + sizeof(buf) - bytes, bytes);
crypto_xor_cpy(dst, src, final,
walk.total % AES_BLOCK_SIZE);
err = skcipher_walk_done(&walk, 0);
break;
}
err = skcipher_walk_done(&walk,
walk.nbytes - blocks * AES_BLOCK_SIZE);
err = skcipher_walk_done(&walk, walk.nbytes - bytes);
}
return err;