Faster Q3_K implementation on Metal (#2307)
* Faster Q3_K on Metal * Additional Q3_K speedup on Metal * Q3_K for QK_K = 64 * Better Q3_K for QK_K = 64 21.6 ms/t -> 21.1 ms/t --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
		
							parent
							
								
									0db14fef06
								
							
						
					
					
						commit
						4d76a5f49b
					
				
					 2 changed files with 125 additions and 82 deletions
				
			
		
							
								
								
									
										15
									
								
								ggml-metal.m
									
										
									
									
									
								
							
							
						
						
									
										15
									
								
								ggml-metal.m
									
										
									
									
									
								
							|  | @ -685,8 +685,8 @@ void ggml_metal_graph_compute( | ||||||
|                                             GGML_ASSERT(ne02 == 1); |                                             GGML_ASSERT(ne02 == 1); | ||||||
|                                             GGML_ASSERT(ne12 == 1); |                                             GGML_ASSERT(ne12 == 1); | ||||||
| 
 | 
 | ||||||
|                                             nth0 = 4; |                                             nth0 = 2; | ||||||
|                                             nth1 = 16; |                                             nth1 = 32; | ||||||
|                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; |                                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; | ||||||
|                                         } break; |                                         } break; | ||||||
|                                     case GGML_TYPE_Q4_K: |                                     case GGML_TYPE_Q4_K: | ||||||
|  | @ -743,15 +743,18 @@ void ggml_metal_graph_compute( | ||||||
|                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { |                                     src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { | ||||||
|                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; |                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||||
|                                 } |                                 } | ||||||
|  |                                 else if (src0t == GGML_TYPE_Q3_K) { | ||||||
|  | #ifdef GGML_QKK_64 | ||||||
|  |                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||||
|  | #else | ||||||
|  |                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||||
|  | #endif | ||||||
|  |                                 } | ||||||
|                                 else if (src0t == GGML_TYPE_Q5_K) { |                                 else if (src0t == GGML_TYPE_Q5_K) { | ||||||
|                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; |                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||||
|                                 } |                                 } | ||||||
|                                 else if (src0t == GGML_TYPE_Q6_K) { |                                 else if (src0t == GGML_TYPE_Q6_K) { | ||||||
|                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; |                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||||
|                                 } |  | ||||||
|                                 else if (src0t == GGML_TYPE_Q3_K) { |  | ||||||
|                                     [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; |  | ||||||
|                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; |  | ||||||
|                                 } else { |                                 } else { | ||||||
|                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; |                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; | ||||||
|                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; |                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||||
|  |  | ||||||
							
								
								
									
										192
									
								
								ggml-metal.metal
									
										
									
									
									
								
							
							
						
						
									
										192
									
								
								ggml-metal.metal
									
										
									
									
									
								
							|  | @ -351,7 +351,7 @@ kernel void kernel_rms_norm( | ||||||
| 
 | 
 | ||||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); |     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|     // broadcast, simd group number is ntg / 32
 |     // broadcast, simd group number is ntg / 32
 | ||||||
|     for (int i = ntg / 32 / 2; i > 0; i /= 2) { |     for (uint i = ntg / 32 / 2; i > 0; i /= 2) { | ||||||
|        if (tpitg < i) { |        if (tpitg < i) { | ||||||
|            sum[tpitg] += sum[tpitg + i]; |            sum[tpitg] += sum[tpitg + i]; | ||||||
|        } |        } | ||||||
|  | @ -1339,6 +1339,7 @@ kernel void kernel_mul_mat_q2_K_f32( | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | #if QK_K == 256 | ||||||
| kernel void kernel_mul_mat_q3_K_f32( | kernel void kernel_mul_mat_q3_K_f32( | ||||||
|         device const  void * src0, |         device const  void * src0, | ||||||
|         device const float * src1, |         device const float * src1, | ||||||
|  | @ -1347,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32( | ||||||
|         constant   int64_t & ne10, |         constant   int64_t & ne10, | ||||||
|         constant   int64_t & ne0, |         constant   int64_t & ne0, | ||||||
|         constant   int64_t & ne1, |         constant   int64_t & ne1, | ||||||
|         threadgroup float  * sum [[threadgroup(0)]], |  | ||||||
|         uint2 tgpig[[threadgroup_position_in_grid]], |         uint2 tgpig[[threadgroup_position_in_grid]], | ||||||
|         uint2 tpitg[[thread_position_in_threadgroup]], |         uint tiisg[[thread_index_in_simdgroup]], | ||||||
|         uint2  tptg[[threads_per_threadgroup]]) { |         uint sgitg[[simdgroup_index_in_threadgroup]]) { | ||||||
| 
 | 
 | ||||||
|     const int nb = ne00/QK_K; |     const int nb = ne00/QK_K; | ||||||
| 
 | 
 | ||||||
|     const int64_t r0 = tgpig.x; |     const int64_t r0 = tgpig.x; | ||||||
|     const int64_t r1 = tgpig.y; |     const int64_t r1 = tgpig.y; | ||||||
| 
 | 
 | ||||||
|     device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb; |     const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; | ||||||
|  | 
 | ||||||
|  |     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb; | ||||||
|     device const float     * yy = (device const float      *) src1 + r1*ne10; |     device const float     * yy = (device const float      *) src1 + r1*ne10; | ||||||
| 
 | 
 | ||||||
|     const int nth = tptg.x*tptg.y; |     float yl[16]; | ||||||
|     const int ith = tptg.y*tpitg.x + tpitg.y; |  | ||||||
| 
 |  | ||||||
| #if QK_K == 256 |  | ||||||
| 
 |  | ||||||
|     const uint8_t m3 = 3; |  | ||||||
|     const int8_t  m4 = 4; |  | ||||||
| 
 | 
 | ||||||
|     const uint16_t kmask1 = 0x0303; |     const uint16_t kmask1 = 0x0303; | ||||||
|     const uint16_t kmask2 = 0x0f0f; |     const uint16_t kmask2 = 0x0f0f; | ||||||
| 
 | 
 | ||||||
|     const int tid = tpitg.y;        // expecting 16
 |     const int tid = tiisg/2; | ||||||
|  |     const int ix  = tiisg%2; | ||||||
|     const int ip  = tid/8;          // 0 or 1
 |     const int ip  = tid/8;          // 0 or 1
 | ||||||
|     const int il  = tid/2 - 4*ip;   // 0...3
 |     const int il  = tid/2 - 4*ip;   // 0...3
 | ||||||
|     const int ir  = tid%2; |     const int ir  = tid%2; | ||||||
|     const int n   = 8; |     const int n   = 8; | ||||||
|     const int l0  = n*ir; |     const int l0  = n*ir; | ||||||
| 
 | 
 | ||||||
|     const uint8_t m = 1 << (4*ip + il); |     const uint16_t m1 = 1 << (4*ip + il); | ||||||
|  |     const uint16_t m2 = m1 << 8; | ||||||
| 
 | 
 | ||||||
|     const int shift = 2*il; |     const int shift = 2*il; | ||||||
|  |     const uint16_t qm1 = 0x0003 << shift; | ||||||
|  |     const uint16_t qm2 = 0x0300 << shift; | ||||||
|  |     const int32_t v1 = 4 << shift; | ||||||
|  |     const int32_t v2 = 1024 << shift; | ||||||
| 
 | 
 | ||||||
|     const uint16_t s_shift1 = 4*ip; |     const uint16_t s_shift1 = 4*ip; | ||||||
|     const uint16_t s_shift2 = s_shift1 + 2*(il/2); |     const uint16_t s_shift2 = s_shift1 + 2*(il/2); | ||||||
|  | @ -1389,93 +1391,132 @@ kernel void kernel_mul_mat_q3_K_f32( | ||||||
|     const int q_offset = 32*ip + l0; |     const int q_offset = 32*ip + l0; | ||||||
|     const int y_offset = 128*ip + 32*il + l0; |     const int y_offset = 128*ip + 32*il + l0; | ||||||
| 
 | 
 | ||||||
|     //float sumf = 0;
 |     const int step = sizeof(block_q3_K) * nb / 2; | ||||||
|     float sumf1 = 0, sumf2 = 0; |  | ||||||
|     for (int i = tpitg.x; i < nb; i += tptg.x) { |  | ||||||
| 
 | 
 | ||||||
|         const float d_all = (float)(x[i].d); |     device const float * y1 = yy + ix*QK_K + y_offset; | ||||||
| 
 | 
 | ||||||
|         device const uint8_t * q = x[i].qs + q_offset; |     float sumf1[2] = {0.f}, sumf2[2] = {0.f}; | ||||||
|         device const uint8_t * h = x[i].hmask + l0; |     for (int i = ix; i < nb; i += 2) { | ||||||
|         device const float   * y = yy + i * QK_K + y_offset; |  | ||||||
| 
 | 
 | ||||||
|         device const uint16_t * a = (device const uint16_t *)x[i].scales; |         for (int l = 0; l < 8; ++l) { | ||||||
|         const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); |             yl[l+0] = y1[l+ 0]; | ||||||
| 
 |             yl[l+8] = y1[l+16]; | ||||||
|         float s = 0; |  | ||||||
|         for (int l = 0; l < n; ++l) { |  | ||||||
|             s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4)); |  | ||||||
|         } |         } | ||||||
|         float d = d_all * s; |  | ||||||
|         sumf1 += d * scales[0]; |  | ||||||
|         sumf2 += d; |  | ||||||
|         //sumf += d_all * s * (scales[0] - 32);
 |  | ||||||
| 
 | 
 | ||||||
|         s = 0; |         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); | ||||||
|         for (int l = 0; l < n; ++l) { |         device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); | ||||||
|             s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4)); |         device const uint16_t * a = (device const uint16_t *)(x[i].scales); | ||||||
|  |         device const half * dh = &x[i].d; | ||||||
|  | 
 | ||||||
|  |         for (int row = 0; row < 2; ++row) { | ||||||
|  | 
 | ||||||
|  |             const float d_all = (float)dh[0]; | ||||||
|  |             const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); | ||||||
|  | 
 | ||||||
|  |             float s1 = 0, s2 = 0; | ||||||
|  |             for (int l = 0; l < n; l += 2) { | ||||||
|  |                 const uint16_t qs = q[l/2]; | ||||||
|  |                 s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1)); | ||||||
|  |                 s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2)); | ||||||
|  |             } | ||||||
|  |             float d = d_all * (s1 + 1.f/256.f * s2); | ||||||
|  |             sumf1[row] += d * scales[0]; | ||||||
|  |             sumf2[row] += d; | ||||||
|  | 
 | ||||||
|  |             s1 = s2 = 0; | ||||||
|  |             for (int l = 0; l < n; l += 2) { | ||||||
|  |                 const uint16_t qs = q[l/2+8]; | ||||||
|  |                 s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1)); | ||||||
|  |                 s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2)); | ||||||
|  |             } | ||||||
|  |             d = d_all * (s1 + 1.f/256.f * s2); | ||||||
|  |             sumf1[row] += d * scales[1]; | ||||||
|  |             sumf2[row] += d; | ||||||
|  | 
 | ||||||
|  |             q  += step; | ||||||
|  |             h  += step; | ||||||
|  |             a  += step; | ||||||
|  |             dh += step; | ||||||
|  | 
 | ||||||
|         } |         } | ||||||
|         d = d_all * s; | 
 | ||||||
|         sumf1 += d * scales[1]; |         y1 += 2 * QK_K; | ||||||
|         sumf2 += d; |  | ||||||
|         //sumf += d_all * s * (scales[1] - 32);
 |  | ||||||
| 
 | 
 | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     //sum[ith] = sumf;
 |     for (int row = 0; row < 2; ++row) { | ||||||
|     sum[ith] = sumf1 - 32.f*sumf2; |         const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); | ||||||
|  |         const float tot = simd_sum(sumf); | ||||||
|  |         if (tiisg == 0) { | ||||||
|  |             dst[r1*ne0 + first_row + row] = tot; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
| #else | #else | ||||||
|     const int il = 4 * tpitg.x;  // 0, 4, 8, 12
 | kernel void kernel_mul_mat_q3_K_f32( | ||||||
|  |         device const  void * src0, | ||||||
|  |         device const float * src1, | ||||||
|  |         device       float * dst, | ||||||
|  |         constant   int64_t & ne00, | ||||||
|  |         constant   int64_t & ne10, | ||||||
|  |         constant   int64_t & ne0, | ||||||
|  |         constant   int64_t & ne1, | ||||||
|  |         uint2 tgpig[[threadgroup_position_in_grid]], | ||||||
|  |         uint tiisg[[thread_index_in_simdgroup]], | ||||||
|  |         uint sgitg[[simdgroup_index_in_threadgroup]]) { | ||||||
|  | 
 | ||||||
|  |     const int nb = ne00/QK_K; | ||||||
|  | 
 | ||||||
|  |     const int64_t r0 = tgpig.x; | ||||||
|  |     const int64_t r1 = tgpig.y; | ||||||
|  | 
 | ||||||
|  |     const int row = 2 * r0 + sgitg; | ||||||
|  | 
 | ||||||
|  |     device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb; | ||||||
|  |     device const float     * yy = (device const float      *) src1 + r1*ne10; | ||||||
|  |     const int ix = tiisg/4; | ||||||
|  |     const int il = 4 * (tiisg%4);// 0, 4, 8, 12
 | ||||||
|     const int im = il/8;         // 0, 0, 1, 1
 |     const int im = il/8;         // 0, 0, 1, 1
 | ||||||
|     const int in = il%8;         // 0, 4, 0, 4
 |     const int in = il%8;         // 0, 4, 0, 4
 | ||||||
| 
 | 
 | ||||||
|     float sumf = 0; |     float2 sum = {0.f, 0.f}; | ||||||
| 
 | 
 | ||||||
|     for (int i = tpitg.y; i < nb; i += tptg.y) { |     for (int i = ix; i < nb; i += 8) { | ||||||
| 
 | 
 | ||||||
|         const float d_all = (float)(x[i].d); |         const float d_all = (float)(x[i].d); | ||||||
| 
 | 
 | ||||||
|         device const uint8_t * q = x[i].qs + il; |         device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); | ||||||
|         device const uint8_t * h = x[i].hmask + in; |         device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); | ||||||
|         device const float   * y = yy + i * QK_K + il; |         device const uint16_t * s = (device const uint16_t *)(x[i].scales); | ||||||
|  |         device const float    * y = yy + i * QK_K + il; | ||||||
| 
 | 
 | ||||||
|         const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); |         const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8); | ||||||
|         const float d2 = d_all * ((x[i].scales[0] >>  4) - 8); |         const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f; | ||||||
|         const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); |         const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f; | ||||||
|         const float d4 = d_all * ((x[i].scales[1] >>  4) - 8); |         const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; | ||||||
| 
 | 
 | ||||||
|         for (int l = 0; l < 4; ++l) { |         for (int l = 0; l < 4; l += 2) { | ||||||
|             const uint8_t hm = h[l] >> im; |             const uint16_t hm = h[l/2] >> im; | ||||||
|             sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4)) |             sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 :  4)) | ||||||
|                   + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4)) |                     + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) | ||||||
|                   + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4)) |                     + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) | ||||||
|                   + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4)); |                     + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); | ||||||
|  |             sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)) | ||||||
|  |                     + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)) | ||||||
|  |                     + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)) | ||||||
|  |                     + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536)); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|     } |     } | ||||||
|  |     const float sumf = sum[0] + sum[1] * 1.f/256.f; | ||||||
| 
 | 
 | ||||||
|     sum[ith] = sumf; |     const float tot = simd_sum(sumf); | ||||||
| 
 |     if (tiisg == 0) { | ||||||
| #endif |         dst[r1*ne0 + row] = tot; | ||||||
| 
 |  | ||||||
|     //
 |  | ||||||
|     // Accumulate the sum from all threads in the threadgroup
 |  | ||||||
|     //
 |  | ||||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); |  | ||||||
|     if (ith%4 == 0) { |  | ||||||
|         for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; |  | ||||||
|     } |  | ||||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); |  | ||||||
|     if (ith%16 == 0) { |  | ||||||
|         for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; |  | ||||||
|     } |  | ||||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); |  | ||||||
|     if (ith == 0) { |  | ||||||
|         for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; |  | ||||||
|         dst[r1*ne0 + r0] = sum[0]; |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| } | } | ||||||
|  | #endif | ||||||
| 
 | 
 | ||||||
| #if QK_K == 256 | #if QK_K == 256 | ||||||
| kernel void kernel_mul_mat_q4_K_f32( | kernel void kernel_mul_mat_q4_K_f32( | ||||||
|  | @ -1773,7 +1814,6 @@ kernel void kernel_mul_mat_q5_K_f32( | ||||||
| 
 | 
 | ||||||
|     for (int i = ix; i < nb; i += 8) { |     for (int i = ix; i < nb; i += 8) { | ||||||
| 
 | 
 | ||||||
|         float4 sumy = {0.f, 0.f, 0.f, 0.f}; |  | ||||||
|         for (int l = 0; l < 4; ++l) { |         for (int l = 0; l < 4; ++l) { | ||||||
|             yl[l+0] = y[l+ 0]; |             yl[l+0] = y[l+ 0]; | ||||||
|             yl[l+4] = y[l+16]; |             yl[l+4] = y[l+16]; | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue