ggml : optimize llamafile cpu matrix multiplication for ppc64le (#10156)
This change upstreams llamafile's cpu matrix multiplication kernels for ppc64le using MMA builtins for FP32 datatype. This change results in a consistent 90% improvement in input processing time, and 20% to 80% improvement in output processing time, across various batch sizes. The patch is tested with Meta-Lllama-3-8B, Mistral-7B, Llama-2-7B-chat-hf models on a IBM POWER10 machine. Signed-off-by: Amrita H S <amritahs@linux.vnet.ibm.com>
This commit is contained in:
		
							parent
							
								
									8fc393f246
								
							
						
					
					
						commit
						e89213492d
					
				
					 2 changed files with 615 additions and 2 deletions
				
			
		|  | @ -1265,8 +1265,13 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW | |||
|     endif() | ||||
| elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") | ||||
|     message(STATUS "PowerPC detected") | ||||
|     if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") | ||||
|         list(APPEND ARCH_FLAGS -mcpu=powerpc64le) | ||||
|     execute_process(COMMAND bash -c "grep POWER10 /proc/cpuinfo | head -n 1" | ||||
|                    OUTPUT_VARIABLE POWER10_M) | ||||
|     string(FIND ${POWER10_M} "POWER10" substring_index) | ||||
|     if(${substring_index} GREATER_EQUAL 0) | ||||
|        list(APPEND ARCH_FLAGS -mcpu=power10) | ||||
|     elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le") | ||||
|        list(APPEND ARCH_FLAGS -mcpu=powerpc64le) | ||||
|     else() | ||||
|         list(APPEND ARCH_FLAGS -mcpu=native -mtune=native) | ||||
|         #TODO: Add  targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be) | ||||
|  |  | |||
|  | @ -106,6 +106,10 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } | |||
| inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } | ||||
| #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 | ||||
| 
 | ||||
| #if defined(__MMA__) | ||||
| typedef vector unsigned char vec_t; | ||||
| typedef __vector_quad acc_t; | ||||
| #endif | ||||
| ////////////////////////////////////////////////////////////////////////////////////////////////////
 | ||||
| // VECTORIZED FUSED MULTIPLY ADD
 | ||||
| 
 | ||||
|  | @ -1026,6 +1030,600 @@ class tinyBLAS_Q0_AVX { | |||
| }; | ||||
| #endif // __AVX__
 | ||||
| 
 | ||||
| //PPC Implementation
 | ||||
| #if defined(__MMA__) | ||||
| 
 | ||||
| #define SAVE_ACC(ACC, ii, jj) \ | ||||
|    __builtin_mma_disassemble_acc(vec_C, ACC); \ | ||||
|    for (int I = 0; I < 4; I++) { \ | ||||
|       for (int J = 0; J < 4; J++) { \ | ||||
|          *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \ | ||||
|       } \ | ||||
|    } \ | ||||
| 
 | ||||
| template <typename TA, typename TB, typename TC> | ||||
| class tinyBLAS_PPC { | ||||
|   public: | ||||
|     tinyBLAS_PPC(int64_t k, | ||||
|                 const TA *A, int64_t lda, | ||||
|                 const TB *B, int64_t ldb, | ||||
|                 TC *C, int64_t ldc, | ||||
|                 int ith, int nth) | ||||
|         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { | ||||
|     } | ||||
| 
 | ||||
|     void matmul(int64_t m, int64_t n) { | ||||
|        mnpack(0, m, 0, n); | ||||
|     } | ||||
| 
 | ||||
|   private: | ||||
| 
 | ||||
|     void (tinyBLAS_PPC::*kernel)(int64_t, int64_t); | ||||
| 
 | ||||
|     void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) { | ||||
|         int64_t i, j; | ||||
|         float *aoffset = NULL, *boffset = NULL; | ||||
|         float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; | ||||
|         float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; | ||||
| 
 | ||||
|         aoffset = const_cast<float*>(a); | ||||
|         boffset = vec; | ||||
|         j = (rows >> 3); | ||||
|         if (j > 0) { | ||||
|             do { | ||||
|                 aoffset1 = aoffset; | ||||
|                 aoffset2 = aoffset1 + lda; | ||||
|                 aoffset3 = aoffset2 + lda; | ||||
|                 aoffset4 = aoffset3 + lda; | ||||
|                 aoffset5 = aoffset4 + lda; | ||||
|                 aoffset6 = aoffset5 + lda; | ||||
|                 aoffset7 = aoffset6 + lda; | ||||
|                 aoffset8 = aoffset7 + lda; | ||||
|                 aoffset += 8 * lda; | ||||
|                 i = (cols >> 3); | ||||
|                 if (i > 0) { | ||||
|                     __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; | ||||
|                     vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2]; | ||||
|                     vector float t1, t2, t3, t4, t5, t6, t7, t8; | ||||
|                     do { | ||||
|                         C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); | ||||
|                         C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); | ||||
|                         C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); | ||||
|                         C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); | ||||
|                         C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5); | ||||
|                         C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6); | ||||
|                         C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7); | ||||
|                         C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8); | ||||
|                         __builtin_vsx_disassemble_pair(c1, &C1); | ||||
|                         __builtin_vsx_disassemble_pair(c2, &C2); | ||||
|                         __builtin_vsx_disassemble_pair(c3, &C3); | ||||
|                         __builtin_vsx_disassemble_pair(c4, &C4); | ||||
|                         __builtin_vsx_disassemble_pair(c5, &C5); | ||||
|                         __builtin_vsx_disassemble_pair(c6, &C6); | ||||
|                         __builtin_vsx_disassemble_pair(c7, &C7); | ||||
|                         __builtin_vsx_disassemble_pair(c8, &C8); | ||||
| 
 | ||||
|                         t1 = vec_mergeh(c1[0], c2[0]); | ||||
|                         t2 = vec_mergeh(c3[0], c4[0]); | ||||
|                         t3 = vec_mergeh(c5[0], c6[0]); | ||||
|                         t4 = vec_mergeh(c7[0], c8[0]); | ||||
|                         t5 = vec_xxpermdi(t1, t2, 0); | ||||
|                         t6 = vec_xxpermdi(t3, t4, 0); | ||||
|                         t7 = vec_xxpermdi(t1, t2, 3); | ||||
|                         t8 = vec_xxpermdi(t3, t4, 3); | ||||
|                         vec_xst(t5, 0, boffset); | ||||
|                         vec_xst(t6, 0, boffset+4); | ||||
|                         vec_xst(t7, 0, boffset+8); | ||||
|                         vec_xst(t8, 0, boffset+12); | ||||
| 
 | ||||
|                         t1 = vec_mergel(c1[0], c2[0]); | ||||
|                         t2 = vec_mergel(c3[0], c4[0]); | ||||
|                         t3 = vec_mergel(c5[0], c6[0]); | ||||
|                         t4 = vec_mergel(c7[0], c8[0]); | ||||
|                         t5 = vec_xxpermdi(t1, t2, 0); | ||||
|                         t6 = vec_xxpermdi(t3, t4, 0); | ||||
|                         t7 = vec_xxpermdi(t1, t2, 3); | ||||
|                         t8 = vec_xxpermdi(t3, t4, 3); | ||||
|                         vec_xst(t5, 0, boffset+16); | ||||
|                         vec_xst(t6, 0, boffset+20); | ||||
|                         vec_xst(t7, 0, boffset+24); | ||||
|                         vec_xst(t8, 0, boffset+28); | ||||
| 
 | ||||
|                         t1 = vec_mergeh(c1[1], c2[1]); | ||||
|                         t2 = vec_mergeh(c3[1], c4[1]); | ||||
|                         t3 = vec_mergeh(c5[1], c6[1]); | ||||
|                         t4 = vec_mergeh(c7[1], c8[1]); | ||||
|                         t5 = vec_xxpermdi(t1, t2, 0); | ||||
|                         t6 = vec_xxpermdi(t3, t4, 0); | ||||
|                         t7 = vec_xxpermdi(t1, t2, 3); | ||||
|                         t8 = vec_xxpermdi(t3, t4, 3); | ||||
|                         vec_xst(t5, 0, boffset+32); | ||||
|                         vec_xst(t6, 0, boffset+36); | ||||
|                         vec_xst(t7, 0, boffset+40); | ||||
|                         vec_xst(t8, 0, boffset+44); | ||||
| 
 | ||||
|                         t1 = vec_mergel(c1[1], c2[1]); | ||||
|                         t2 = vec_mergel(c3[1], c4[1]); | ||||
|                         t3 = vec_mergel(c5[1], c6[1]); | ||||
|                         t4 = vec_mergel(c7[1], c8[1]); | ||||
|                         t5 = vec_xxpermdi(t1, t2, 0); | ||||
|                         t6 = vec_xxpermdi(t3, t4, 0); | ||||
|                         t7 = vec_xxpermdi(t1, t2, 3); | ||||
|                         t8 = vec_xxpermdi(t3, t4, 3); | ||||
|                         vec_xst(t5, 0, boffset+48); | ||||
|                         vec_xst(t6, 0, boffset+52); | ||||
|                         vec_xst(t7, 0, boffset+56); | ||||
|                         vec_xst(t8, 0, boffset+60); | ||||
| 
 | ||||
|                         aoffset1 += 8*lda; | ||||
|                         aoffset2 += 8*lda; | ||||
|                         aoffset3 += 8*lda; | ||||
|                         aoffset4 += 8*lda; | ||||
|                         boffset += 64; | ||||
|                         i--; | ||||
|                     } while(i > 0); | ||||
|                 } | ||||
|                 if (cols & 4) { | ||||
|                     vector float c1, c2, c3, c4, c5, c6, c7, c8; | ||||
|                     vector float t1, t2, t3, t4, t5, t6, t7, t8; | ||||
|                     c1 = vec_xl(0, aoffset1); | ||||
|                     c2 = vec_xl(0, aoffset2); | ||||
|                     c3 = vec_xl(0, aoffset3); | ||||
|                     c4 = vec_xl(0, aoffset4); | ||||
|                     c5 = vec_xl(0, aoffset5); | ||||
|                     c6 = vec_xl(0, aoffset6); | ||||
|                     c7 = vec_xl(0, aoffset7); | ||||
|                     c8 = vec_xl(0, aoffset8); | ||||
| 
 | ||||
|                     t1 = vec_mergeh(c1, c2); | ||||
|                     t2 = vec_mergeh(c3, c4); | ||||
|                     t3 = vec_mergeh(c5, c6); | ||||
|                     t4 = vec_mergeh(c7, c8); | ||||
|                     t5 = vec_xxpermdi(t1, t2, 0); | ||||
|                     t6 = vec_xxpermdi(t3, t4, 0); | ||||
|                     t7 = vec_xxpermdi(t1, t2, 3); | ||||
|                     t8 = vec_xxpermdi(t3, t4, 3); | ||||
|                     vec_xst(t5, 0, boffset); | ||||
|                     vec_xst(t6, 0, boffset+4); | ||||
|                     vec_xst(t7, 0, boffset+8); | ||||
|                     vec_xst(t8, 0, boffset+12); | ||||
| 
 | ||||
|                     t1 = vec_mergel(c1, c2); | ||||
|                     t2 = vec_mergel(c3, c4); | ||||
|                     t3 = vec_mergel(c5, c6); | ||||
|                     t4 = vec_mergel(c7, c8); | ||||
|                     t5 = vec_xxpermdi(t1, t2, 0); | ||||
|                     t6 = vec_xxpermdi(t3, t4, 0); | ||||
|                     t7 = vec_xxpermdi(t1, t2, 3); | ||||
|                     t8 = vec_xxpermdi(t3, t4, 3); | ||||
|                     vec_xst(t5, 0, boffset+16); | ||||
|                     vec_xst(t6, 0, boffset+20); | ||||
|                     vec_xst(t7, 0, boffset+24); | ||||
|                     vec_xst(t8, 0, boffset+28); | ||||
|                 } | ||||
|             j--; | ||||
|             } while(j > 0); | ||||
|         } | ||||
| 
 | ||||
|         if (rows & 4) { | ||||
|             aoffset1 = aoffset; | ||||
|             aoffset2 = aoffset1 + lda; | ||||
|             aoffset3 = aoffset2 + lda; | ||||
|             aoffset4 = aoffset3 + lda; | ||||
|             aoffset += 4 * lda; | ||||
|             i = (cols >> 3); | ||||
|             if (i > 0) { | ||||
|                 __vector_pair C1, C2, C3, C4; | ||||
|                 vector float c1[2], c2[2], c3[2], c4[2]; | ||||
|                 vector float t1, t2, t3, t4, t5, t6, t7, t8; | ||||
|                 do { | ||||
|                     C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); | ||||
|                     C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); | ||||
|                     C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); | ||||
|                     C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); | ||||
|                     __builtin_vsx_disassemble_pair(c1, &C1); | ||||
|                     __builtin_vsx_disassemble_pair(c2, &C2); | ||||
|                     __builtin_vsx_disassemble_pair(c3, &C3); | ||||
|                     __builtin_vsx_disassemble_pair(c4, &C4); | ||||
| 
 | ||||
|                     t1 = vec_mergeh(c1[0], c2[0]); | ||||
|                     t2 = vec_mergeh(c3[0], c4[0]); | ||||
|                     t3 = vec_mergel(c1[0], c2[0]); | ||||
|                     t4 = vec_mergel(c3[0], c4[0]); | ||||
|                     t5 = vec_xxpermdi(t1, t2, 0); | ||||
|                     t6 = vec_xxpermdi(t1, t2, 3); | ||||
|                     t7 = vec_xxpermdi(t3, t4, 0); | ||||
|                     t8 = vec_xxpermdi(t3, t4, 3); | ||||
|                     vec_xst(t5, 0, boffset); | ||||
|                     vec_xst(t6, 0, boffset+4); | ||||
|                     vec_xst(t7, 0, boffset+8); | ||||
|                     vec_xst(t8, 0, boffset+12); | ||||
| 
 | ||||
|                     t1 = vec_mergeh(c1[1], c2[1]); | ||||
|                     t2 = vec_mergeh(c3[1], c4[1]); | ||||
|                     t3 = vec_mergel(c1[1], c2[1]); | ||||
|                     t4 = vec_mergel(c3[1], c4[1]); | ||||
|                     t5 = vec_xxpermdi(t1, t2, 0); | ||||
|                     t6 = vec_xxpermdi(t1, t2, 3); | ||||
|                     t7 = vec_xxpermdi(t3, t4, 0); | ||||
|                     t8 = vec_xxpermdi(t3, t4, 3); | ||||
|                     vec_xst(t5, 0, boffset+16); | ||||
|                     vec_xst(t6, 0, boffset+20); | ||||
|                     vec_xst(t7, 0, boffset+24); | ||||
|                     vec_xst(t8, 0, boffset+28); | ||||
| 
 | ||||
|                     aoffset1 += 8*lda; | ||||
|                     aoffset2 += 8*lda; | ||||
|                     aoffset3 += 8*lda; | ||||
|                     aoffset4 += 8*lda; | ||||
|                     boffset += 32; | ||||
|                     i--; | ||||
|                 } while(i > 0); | ||||
|             } | ||||
| 
 | ||||
|             if (cols & 4) { | ||||
|                 vector float c1, c2, c3, c4; | ||||
|                 vector float t1, t2, t3, t4; | ||||
|                 c1 = vec_xl(0, aoffset1); | ||||
|                 c2 = vec_xl(0, aoffset2); | ||||
|                 c3 = vec_xl(0, aoffset3); | ||||
|                 c4 = vec_xl(0, aoffset4); | ||||
| 
 | ||||
|                 t1 = vec_mergeh(c1, c2); | ||||
|                 t2 = vec_mergeh(c3, c4); | ||||
|                 t3 = vec_xxpermdi(t1, t2, 0); | ||||
|                 t4 = vec_xxpermdi(t1, t2, 3); | ||||
|                 vec_xst(t3, 0, boffset); | ||||
|                 vec_xst(t4, 0, boffset+4); | ||||
| 
 | ||||
|                 t1 = vec_mergel(c1, c2); | ||||
|                 t2 = vec_mergel(c3, c4); | ||||
|                 t3 = vec_xxpermdi(t1, t2, 0); | ||||
|                 t4 = vec_xxpermdi(t1, t2, 3); | ||||
|                 vec_xst(t3, 0, boffset+8); | ||||
|                 vec_xst(t4, 0, boffset+12); | ||||
|             } | ||||
|         } | ||||
|         if (rows & 3) { | ||||
|             aoffset1 = aoffset; | ||||
|             aoffset2 = aoffset1 + lda; | ||||
|             aoffset3 = aoffset2 + lda; | ||||
|             if (cols & 4) { | ||||
|                 vector float c1, c2, c3, c4 = {0}; | ||||
|                 vector float t1, t2, t3, t4; | ||||
|                 c1 = vec_xl(0, aoffset1); | ||||
|                 c2 = vec_xl(0, aoffset2); | ||||
|                 c3 = vec_xl(0, aoffset3); | ||||
| 
 | ||||
|                 t1 = vec_mergeh(c1, c2); | ||||
|                 t2 = vec_mergeh(c3, c4); | ||||
|                 t3 = vec_xxpermdi(t1, t2, 0); | ||||
|                 t4 = vec_xxpermdi(t1, t2, 3); | ||||
|                 vec_xst(t3, 0, boffset); | ||||
|                 vec_xst(t4, 0, boffset+4); | ||||
| 
 | ||||
|                 t1 = vec_mergel(c1, c2); | ||||
|                 t2 = vec_mergel(c3, c4); | ||||
|                 t3 = vec_xxpermdi(t1, t2, 0); | ||||
|                 t4 = vec_xxpermdi(t1, t2, 3); | ||||
|                 vec_xst(t3, 0, boffset+8); | ||||
|                 vec_xst(t4, 0, boffset+12); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     void KERNEL_4x4(int64_t ii, int64_t jj) { | ||||
|         vec_t vec_A[4], vec_B[4], vec_C[4]; | ||||
|         acc_t acc_0; | ||||
|         __builtin_mma_xxsetaccz(&acc_0); | ||||
|         for (int l = 0; l < k; l+=4) { | ||||
|             READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); | ||||
|             READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); | ||||
|             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]); | ||||
|         } | ||||
|         SAVE_ACC(&acc_0, ii, jj); | ||||
|     } | ||||
| 
 | ||||
|     void KERNEL_4x8(int64_t ii, int64_t jj) { | ||||
|         vec_t vec_A[4], vec_B[8], vec_C[4]; | ||||
|         acc_t acc_0, acc_1; | ||||
|         __builtin_mma_xxsetaccz(&acc_0); | ||||
|         __builtin_mma_xxsetaccz(&acc_1); | ||||
|         for (int64_t l = 0; l < k; l+=4) { | ||||
|             READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); | ||||
|             READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B); | ||||
|             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]); | ||||
|         } | ||||
|         SAVE_ACC(&acc_0, ii, jj); | ||||
|         SAVE_ACC(&acc_1, ii, jj+4); | ||||
|     } | ||||
| 
 | ||||
|     void KERNEL_8x4(int64_t ii, int64_t jj) { | ||||
|         vec_t vec_A[8], vec_B[4], vec_C[4]; | ||||
|         acc_t acc_0, acc_1; | ||||
|         __builtin_mma_xxsetaccz(&acc_0); | ||||
|         __builtin_mma_xxsetaccz(&acc_1); | ||||
|         for (int64_t l = 0; l < k; l+=4) { | ||||
|             READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A); | ||||
|             READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); | ||||
|             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]); | ||||
|             __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]); | ||||
|         } | ||||
|         SAVE_ACC(&acc_0, ii, jj); | ||||
|         SAVE_ACC(&acc_1, ii+4, jj); | ||||
|     } | ||||
| 
 | ||||
|     void KERNEL_8x8(int64_t ii, int64_t jj) { | ||||
|         vec_t vec_A[16], vec_B[16], vec_C[4]; | ||||
|         acc_t acc_0, acc_1, acc_2, acc_3; | ||||
|         __builtin_mma_xxsetaccz(&acc_0); | ||||
|         __builtin_mma_xxsetaccz(&acc_1); | ||||
|         __builtin_mma_xxsetaccz(&acc_2); | ||||
|         __builtin_mma_xxsetaccz(&acc_3); | ||||
|         for (int l = 0; l < k; l+=8) { | ||||
|             READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A); | ||||
|             READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B); | ||||
|             for(int x = 0; x < 16; x+=2) { | ||||
|                 __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]); | ||||
|                 __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]); | ||||
|                 __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]); | ||||
|                 __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]); | ||||
|             } | ||||
|         } | ||||
|         SAVE_ACC(&acc_0, ii, jj); | ||||
|         SAVE_ACC(&acc_1, ii, jj+4); | ||||
|         SAVE_ACC(&acc_2, ii+4, jj); | ||||
|         SAVE_ACC(&acc_3, ii+4, jj+4); | ||||
|     } | ||||
| 
 | ||||
|     void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { | ||||
|         int64_t mc, nc, mp, np; | ||||
|         int m_rem = MIN(m - m0, 16); | ||||
|         int n_rem = MIN(n - n0, 16); | ||||
|         if (m_rem >= 16 && n_rem >= 8) { | ||||
|             mc = 8; | ||||
|             nc = 8; | ||||
|             gemm<8,8>(m0, m, n0, n); | ||||
|         } else if(m_rem >= 8 && n_rem >= 16) { | ||||
|             mc = 8; | ||||
|             nc = 8; | ||||
|             gemm<8,8>(m0, m, n0, n); | ||||
|         } else if (m_rem >= 8 && n_rem >= 8) { | ||||
|             mc = 8; | ||||
|             nc = 8; | ||||
|             gemm<8,8>(m0, m, n0, n); | ||||
|         } else if (m_rem >= 4 && n_rem >= 8) { | ||||
|             mc = 4; | ||||
|             nc = 8; | ||||
|             gemm<4,8>(m0, m, n0, n); | ||||
|         } else if (m_rem >= 8 && n_rem >= 4) { | ||||
|             mc = 8; | ||||
|             nc = 4; | ||||
|             gemm<8,4>(m0, m, n0, n); | ||||
|         } else if (m_rem >= 4 && n_rem >= 4) { | ||||
|             mc = 4; | ||||
|             nc = 4; | ||||
|             gemm<4,4>(m0, m, n0, n); | ||||
|         } else if ((m_rem < 4) && (n_rem > 4)) { | ||||
|             nc = 4; | ||||
|             switch(m_rem) { | ||||
|                 case 1: | ||||
|                     mc = 1; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 2: | ||||
|                     mc = 2; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 3: | ||||
|                     mc = 3; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 default: | ||||
|                     return; | ||||
|             } | ||||
|         } else if ((m_rem > 4) && (n_rem < 4)) { | ||||
|             mc = 4; | ||||
|             switch(n_rem) { | ||||
|                 case 1: | ||||
|                     nc = 1; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 2: | ||||
|                     nc = 2; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 3: | ||||
|                     nc = 3; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 default: | ||||
|                     return; | ||||
|             } | ||||
|         } else { | ||||
|             switch((m_rem << 4) | n_rem) { | ||||
|                 case 0x43: | ||||
|                     mc = 4; | ||||
|                     nc = 3; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x42: | ||||
|                     mc = 4; | ||||
|                     nc = 2; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x41: | ||||
|                     mc = 4; | ||||
|                     nc = 1; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x34: | ||||
|                     mc = 3; | ||||
|                     nc = 4; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x33: | ||||
|                     mc = 3; | ||||
|                     nc = 3; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x32: | ||||
|                     mc = 3; | ||||
|                     nc = 2; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x31: | ||||
|                     mc = 3; | ||||
|                     nc = 1; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x24: | ||||
|                     mc = 2; | ||||
|                     nc = 4; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x23: | ||||
|                     mc = 2; | ||||
|                     nc = 3; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x22: | ||||
|                     mc = 2; | ||||
|                     nc = 2; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x21: | ||||
|                     mc = 2; | ||||
|                     nc = 1; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x14: | ||||
|                     mc = 1; | ||||
|                     nc = 4; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x13: | ||||
|                     mc = 1; | ||||
|                     nc = 3; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x12: | ||||
|                     mc = 1; | ||||
|                     nc = 2; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 case 0x11: | ||||
|                     mc = 1; | ||||
|                     nc = 1; | ||||
|                     gemm_small(m0, m, n0, n, mc, nc); | ||||
|                     break; | ||||
|                 default: | ||||
|                     return; | ||||
|             } | ||||
|         } | ||||
|         mp = m0 + (m - m0) / mc * mc; | ||||
|         np = n0 + (n - n0) / nc * nc; | ||||
|         mnpack(mp, m, n0, np); | ||||
|         mnpack(m0, m, np, n); | ||||
|     } | ||||
| 
 | ||||
|      void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { | ||||
|         int64_t ytiles = (m - m0) / RM; | ||||
|         int64_t xtiles = (n - n0) / RN; | ||||
|         int64_t tiles = xtiles * ytiles; | ||||
|         int64_t duty = (tiles + nth - 1) / nth; | ||||
|         int64_t start = duty * ith; | ||||
|         int64_t end = start + duty; | ||||
|         if (end > tiles) | ||||
|             end = tiles; | ||||
|         for (int64_t job = start; job < end; ++job) { | ||||
|             int64_t ii = m0 + job / xtiles * RM; | ||||
|             int64_t jj = n0 + job % xtiles * RN; | ||||
|             vec_t vec_C[4]; | ||||
|             acc_t acc_0; | ||||
|             __builtin_mma_xxsetaccz(&acc_0); | ||||
|             vec_t vec_A[4], vec_B[4]; | ||||
|             for (int l=0; l<k; l+=4) { | ||||
|                 if (RN >= 4 && RM == 1) { | ||||
|                     float* a = const_cast<float*>(A+(ii)*lda+l); | ||||
|                     READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); | ||||
|                     vec_A[0] = (vec_t)vec_xl(0,a); | ||||
|                     vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1)); | ||||
|                     vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2)); | ||||
|                     vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3)); | ||||
|                 } else { | ||||
|                     READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); | ||||
|                     READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); | ||||
|                 } | ||||
|                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); | ||||
|                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); | ||||
|                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); | ||||
|                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]); | ||||
|             } | ||||
|             __builtin_mma_disassemble_acc(vec_C, &acc_0); | ||||
|             for (int I = 0; I < RM; I++) { | ||||
|                 for (int J = 0; J < RN; J++) { | ||||
|                     *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); | ||||
|                 } | ||||
|             } | ||||
|        } | ||||
|     } | ||||
| 
 | ||||
|     template <int RM, int RN> | ||||
|     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { | ||||
|         int64_t ytiles = (m - m0) / RM; | ||||
|         int64_t xtiles = (n - n0) / RN; | ||||
|         int64_t tiles = xtiles * ytiles; | ||||
|         int64_t duty = (tiles + nth - 1) / nth; | ||||
|         int64_t start = duty * ith; | ||||
|         int64_t end = start + duty; | ||||
|         if (RM == 4 && RN == 4) { | ||||
|             kernel = &tinyBLAS_PPC::KERNEL_4x4; | ||||
|         } else if (RM == 4 && RN == 8) { | ||||
|             kernel = &tinyBLAS_PPC::KERNEL_4x8; | ||||
|         } else if (RM == 8 && RN == 4) { | ||||
|             kernel = &tinyBLAS_PPC::KERNEL_8x4; | ||||
|         } else if (RM == 8 && RN == 8) { | ||||
|             kernel = &tinyBLAS_PPC::KERNEL_8x8; | ||||
|         } | ||||
|         if (end > tiles) | ||||
|             end = tiles; | ||||
|         for (int64_t job = start; job < end; ++job) { | ||||
|             int64_t ii = m0 + job / xtiles * RM; | ||||
|             int64_t jj = n0 + job % xtiles * RN; | ||||
|             (this->*kernel)(ii, jj); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     const TA *const A; | ||||
|     const TB *const B; | ||||
|     TC *C; | ||||
|     TA *At; | ||||
|     TB *Bt; | ||||
|     const int64_t k; | ||||
|     const int64_t lda; | ||||
|     const int64_t ldb; | ||||
|     const int64_t ldc; | ||||
|     const int ith; | ||||
|     const int nth; | ||||
| }; | ||||
| #endif | ||||
| } // namespace
 | ||||
| 
 | ||||
| /**
 | ||||
|  | @ -1114,6 +1712,16 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda | |||
|             ith, nth}; | ||||
|         tb.matmul(m, n); | ||||
|         return true; | ||||
| #elif defined(__MMA__) | ||||
|         if (k % 8) | ||||
|             return false; | ||||
|         tinyBLAS_PPC<float, float, float> tb{ | ||||
|             k, (const float *)A, lda, | ||||
|             (const float *)B, ldb, | ||||
|             (float *)C, ldc, | ||||
|             ith, nth}; | ||||
|         tb.matmul(m, n); | ||||
|         return true; | ||||
| #else | ||||
|         return false; | ||||
| #endif | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue