Make updates to reduce number of load instructions

Srihari-mcw 2024-08-28 00:10:26 -07:00 committed by Srihari-mcw
parent 364dc964ba
commit c950fc3064

@@ -2504,22 +2504,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
             for (int rp = 0; rp < 4; rp++) {
                 // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
                 // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
-                __m256i lhs_mat_01_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs)));
-                lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_01_0, lhs_mat_01_0, 0);
-                __m256i lhs_mat_23_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 16)));
-                lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_23_0, lhs_mat_23_0, 0);
-                __m256i lhs_mat_01_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 32)));
-                lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_01_1, lhs_mat_01_1, 0);
-                __m256i lhs_mat_23_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 48)));
-                lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_23_1, lhs_mat_23_1, 0);
-                __m256i lhs_mat_01_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 64)));
-                lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_01_2, lhs_mat_01_2, 0);
-                __m256i lhs_mat_23_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 80)));
-                lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_23_2, lhs_mat_23_2, 0);
-                __m256i lhs_mat_01_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 96)));
-                lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_01_3, lhs_mat_01_3, 0);
-                __m256i lhs_mat_23_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptrs[rp][b].qs + 112)));
-                lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_23_3, lhs_mat_23_3, 0);
+                __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs)));
+                __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
+                __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
+                __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 32)));
+                __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
+                __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
+                __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 64)));
+                __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
+                __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
+                __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptrs[rp][b].qs + 96)));
+                __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
+                __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
 
                 // Shuffle pattern one - left side input
                 const __m256i lhs_mat_01_0_sp1 = _mm256_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3)
@@ -2670,22 +2666,18 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
             // Load the four block_q4_0 quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
             // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
-            __m256i lhs_mat_01_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs)));
-            lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_01_0, lhs_mat_01_0, 0);
-            __m256i lhs_mat_23_0 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16)));
-            lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_23_0, lhs_mat_23_0, 0);
-            __m256i lhs_mat_01_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32)));
-            lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_01_1, lhs_mat_01_1, 0);
-            __m256i lhs_mat_23_1 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48)));
-            lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_23_1, lhs_mat_23_1, 0);
-            __m256i lhs_mat_01_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 64)));
-            lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_01_2, lhs_mat_01_2, 0);
-            __m256i lhs_mat_23_2 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 80)));
-            lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_23_2, lhs_mat_23_2, 0);
-            __m256i lhs_mat_01_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 96)));
-            lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_01_3, lhs_mat_01_3, 0);
-            __m256i lhs_mat_23_3 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 112)));
-            lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_23_3, lhs_mat_23_3, 0);
+            __m256i lhs_mat_0123_0 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs)));
+            __m256i lhs_mat_01_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 0);
+            __m256i lhs_mat_23_0 = _mm256_permute2f128_si256(lhs_mat_0123_0, lhs_mat_0123_0, 17);
+            __m256i lhs_mat_0123_1 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 32)));
+            __m256i lhs_mat_01_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 0);
+            __m256i lhs_mat_23_1 = _mm256_permute2f128_si256(lhs_mat_0123_1, lhs_mat_0123_1, 17);
+            __m256i lhs_mat_0123_2 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 64)));
+            __m256i lhs_mat_01_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 0);
+            __m256i lhs_mat_23_2 = _mm256_permute2f128_si256(lhs_mat_0123_2, lhs_mat_0123_2, 17);
+            __m256i lhs_mat_0123_3 = _mm256_loadu_si256((const __m256i *)((a_ptr[b].qs + 96)));
+            __m256i lhs_mat_01_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 0);
+            __m256i lhs_mat_23_3 = _mm256_permute2f128_si256(lhs_mat_0123_3, lhs_mat_0123_3, 17);
 
             // Shuffle pattern one - left side input
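
Note (not part of the commit): the change in both hunks follows the same pattern. The old code issued two unaligned 128-bit loads per 32-byte chunk, widened each with _mm256_castsi128_si256, and broadcast it across both lanes; the new code issues a single unaligned 256-bit load per chunk and derives both broadcasts from it with _mm256_permute2f128_si256, control 0 (low half to both lanes) and 17 = 0x11 (high half to both lanes), halving the load count. The standalone sketch below illustrates this pattern in isolation; the buffer, variable names, and the printed check are illustrative only and are not taken from the ggml kernel. It should compile with any AVX-capable compiler, e.g. gcc -mavx sketch.c.

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    int8_t qs[32];
    for (int i = 0; i < 32; i++) qs[i] = (int8_t) i;

    // Old pattern: two 128-bit loads, each widened and broadcast into a 256-bit register.
    __m256i lo_old = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *) qs));
    lo_old = _mm256_permute2f128_si256(lo_old, lo_old, 0);
    __m256i hi_old = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(qs + 16)));
    hi_old = _mm256_permute2f128_si256(hi_old, hi_old, 0);

    // New pattern: one 256-bit load, then broadcast its low half (control 0)
    // and its high half (control 17 = 0x11) into separate registers.
    __m256i both   = _mm256_loadu_si256((const __m256i *) qs);
    __m256i lo_new = _mm256_permute2f128_si256(both, both, 0);
    __m256i hi_new = _mm256_permute2f128_si256(both, both, 17);

    // Store both variants and confirm they produce identical register contents.
    int8_t a[32], b[32], c[32], d[32];
    _mm256_storeu_si256((__m256i *) a, lo_old);
    _mm256_storeu_si256((__m256i *) b, lo_new);
    _mm256_storeu_si256((__m256i *) c, hi_old);
    _mm256_storeu_si256((__m256i *) d, hi_new);
    int same = 1;
    for (int i = 0; i < 32; i++) same &= (a[i] == b[i]) && (c[i] == d[i]);
    printf("patterns match: %s\n", same ? "yes" : "no");
    return 0;
}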