From c950fc306424961e421461d93ff830b4b404d223 Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Wed, 28 Aug 2024 00:10:26 -0700 Subject: [PATCH] Make updates to reduce number of load instructions --- ggml/src/ggml-aarch64.c | 56 ++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c index a17d3c70c..72cb83c9b 100644 --- a/ggml/src/ggml-aarch64.c +++ b/ggml/src/ggml-aarch64.c @@ -2504,22 +2504,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * for (int rp = 0; rp < 4; rp++) { // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 // Loaded as set of 128 bit vectors and repeated into a 256 bit vector - __m256i lhs_mat_01_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs))); - lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_01_0, lhs_mat_01_0, 0); - __m256i lhs_mat_23_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 16))); - lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_23_0, lhs_mat_23_0, 0); - __m256i lhs_mat_01_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 32))); - lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_01_1, lhs_mat_01_1, 0); - __m256i lhs_mat_23_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 48))); - lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_23_1, lhs_mat_23_1, 0); - __m256i lhs_mat_01_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 64))); - lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_01_2, lhs_mat_01_2, 0); - __m256i lhs_mat_23_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 80))); - lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_23_2, lhs_mat_23_2, 0); - __m256i lhs_mat_01_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 96))); - lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_01_3, lhs_mat_01_3, 0); - __m256i lhs_mat_23_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 112))); - lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_23_3, lhs_mat_23_3, 0); + __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs))); + __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); + __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); + __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32))); + __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); + __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); + __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64))); + __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); + __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); + __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96))); + __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); + __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); // Shuffle pattern one - left side input const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) @@ -2670,22 +2666,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 // Loaded as set of 128 bit vectors and repeated into a 256 bit vector - __m256i lhs_mat_01_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs))); - lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_01_0, lhs_mat_01_0, 0); - __m256i lhs_mat_23_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16))); - lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_23_0, lhs_mat_23_0, 0); - __m256i lhs_mat_01_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32))); - lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_01_1, lhs_mat_01_1, 0); - __m256i lhs_mat_23_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48))); - lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_23_1, lhs_mat_23_1, 0); - __m256i lhs_mat_01_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 64))); - lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_01_2, lhs_mat_01_2, 0); - __m256i lhs_mat_23_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 80))); - lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_23_2, lhs_mat_23_2, 0); - __m256i lhs_mat_01_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 96))); - lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_01_3, lhs_mat_01_3, 0); - __m256i lhs_mat_23_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 112))); - lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_23_3, lhs_mat_23_3, 0); + __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs))); + __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0); + __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17); + __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32))); + __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0); + __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17); + __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64))); + __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0); + __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17); + __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96))); + __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0); + __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17); // Shuffle pattern one - left side input