metal : Q3_K speedup (#2995)
* Slightly faster Q3_K and Q5_K on metal * Another Q3_K speedup on metal Combined with previous commit, we are now +9.6% for TG. PP is not affected as this happens via the matrix multiplication templates. * Slowly progressing on Q3_K on metal We are now 13% faster than master * nother small improvement for Q3_K on metal --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
		
							parent
							
								
									e64f5b5578
								
							
						
					
					
						commit
						ba7ffbb251
					
				
					 1 changed files with 89 additions and 46 deletions
				
			
		
							
								
								
									
										129
									
								
								ggml-metal.metal
									
										
									
									
									
								
							
							
						
						
									
										129
									
								
								ggml-metal.metal
									
										
									
									
									
								
							|  | @ -1123,31 +1123,40 @@ kernel void kernel_mul_mat_q3_K_f32( | |||
|     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; | ||||
|     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1; | ||||
| 
 | ||||
|     float yl[16]; | ||||
|     float yl[32]; | ||||
| 
 | ||||
|     const uint16_t kmask1 = 0x0303; | ||||
|     const uint16_t kmask1 = 0x3030; | ||||
|     const uint16_t kmask2 = 0x0f0f; | ||||
| 
 | ||||
|     const int tid = tiisg/2; | ||||
|     const int ix  = tiisg%2; | ||||
|     const int ip  = tid/8;          // 0 or 1
 | ||||
|     const int il  = tid/2 - 4*ip;   // 0...3
 | ||||
|     const int tid = tiisg/4; | ||||
|     const int ix  = tiisg%4; | ||||
|     const int ip  = tid/4;          // 0 or 1
 | ||||
|     const int il  = 2*((tid%4)/2);  // 0 or 2
 | ||||
|     const int ir  = tid%2; | ||||
|     const int n   = 8; | ||||
|     const int l0  = n*ir; | ||||
| 
 | ||||
|     const uint16_t m1 = 1 << (4*ip + il); | ||||
|     const uint16_t m2 = m1 << 8; | ||||
|     // One would think that the Metal compiler would figure out that ip and il can only have
 | ||||
|     // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
 | ||||
|     // with these two tales.
 | ||||
|     //
 | ||||
|     // Possible masks for the high bit
 | ||||
|     const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
 | ||||
|                            {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
 | ||||
|                            {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
 | ||||
|                            {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
 | ||||
| 
 | ||||
|     // Possible masks for the low 2 bits
 | ||||
|     const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; | ||||
| 
 | ||||
|     const ushort4 hm = mm[2*ip + il/2]; | ||||
| 
 | ||||
|     const int shift = 2*il; | ||||
|     const uint16_t qm1 = 0x0003 << shift; | ||||
|     const uint16_t qm2 = 0x0300 << shift; | ||||
|     const int32_t v1 = 4 << shift; | ||||
|     const int32_t v2 = 1024 << shift; | ||||
|     const float    v1 = il == 0 ? 4.f : 64.f; | ||||
|     const float    v2 = 4.f * v1; | ||||
| 
 | ||||
|     const uint16_t s_shift1 = 4*ip; | ||||
|     const uint16_t s_shift2 = s_shift1 + 2*(il/2); | ||||
|     const int ik = 4 + (il%2); | ||||
|     const uint16_t s_shift2 = s_shift1 + il; | ||||
| 
 | ||||
|     const int q_offset = 32*ip + l0; | ||||
|     const int y_offset = 128*ip + 32*il + l0; | ||||
|  | @ -1156,12 +1165,19 @@ kernel void kernel_mul_mat_q3_K_f32( | |||
| 
 | ||||
|     device const float * y1 = yy + ix*QK_K + y_offset; | ||||
| 
 | ||||
|     float sumf1[2] = {0.f}, sumf2[2] = {0.f}; | ||||
|     for (int i = ix; i < nb; i += 2) { | ||||
|     uint32_t scales32, aux32; | ||||
|     thread uint16_t * scales16 = (thread uint16_t *)&scales32; | ||||
|     thread const int8_t * scales = (thread const int8_t *)&scales32; | ||||
| 
 | ||||
|     float sumf1[2] = {0.f}; | ||||
|     float sumf2[2] = {0.f}; | ||||
|     for (int i = ix; i < nb; i += 4) { | ||||
| 
 | ||||
|         for (int l = 0; l < 8; ++l) { | ||||
|             yl[l+ 0] = y1[l+ 0]; | ||||
|             yl[l+ 8] = y1[l+16]; | ||||
|             yl[l+16] = y1[l+32]; | ||||
|             yl[l+24] = y1[l+48]; | ||||
|         } | ||||
| 
 | ||||
|         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); | ||||
|  | @ -1172,27 +1188,43 @@ kernel void kernel_mul_mat_q3_K_f32( | |||
|         for (int row = 0; row < 2; ++row) { | ||||
| 
 | ||||
|             const float d_all = (float)dh[0]; | ||||
|             const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); | ||||
| 
 | ||||
|             float s1 = 0, s2 = 0; | ||||
|             for (int l = 0; l < n; l += 2) { | ||||
|                 const uint16_t qs = q[l/2]; | ||||
|                 s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1)); | ||||
|                 s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2)); | ||||
|             } | ||||
|             float d = d_all * (s1 + 1.f/256.f * s2); | ||||
|             sumf1[row] += d * scales[0]; | ||||
|             sumf2[row] += d; | ||||
|             scales16[0] = a[4]; | ||||
|             scales16[1] = a[5]; | ||||
|             aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; | ||||
|             scales16[0] = a[il+0]; | ||||
|             scales16[1] = a[il+1]; | ||||
|             scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; | ||||
| 
 | ||||
|             s1 = s2 = 0; | ||||
|             float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; | ||||
|             for (int l = 0; l < n; l += 2) { | ||||
|                 const uint16_t qs = q[l/2+8]; | ||||
|                 s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1)); | ||||
|                 s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2)); | ||||
|                 const int32_t qs = q[l/2]; | ||||
|                 s1 += yl[l+0] * (qs & qm[il/2][0]); | ||||
|                 s2 += yl[l+1] * (qs & qm[il/2][1]); | ||||
|                 s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); | ||||
|                 s4 += yl[l+16] * (qs & qm[il/2][2]); | ||||
|                 s5 += yl[l+17] * (qs & qm[il/2][3]); | ||||
|                 s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); | ||||
|             } | ||||
|             d = d_all * (s1 + 1.f/256.f * s2); | ||||
|             sumf1[row] += d * scales[1]; | ||||
|             sumf2[row] += d; | ||||
|             float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); | ||||
|             float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); | ||||
|             sumf1[row] += d1 * (scales[0] - 32); | ||||
|             sumf2[row] += d2 * (scales[2] - 32); | ||||
| 
 | ||||
|             s1 = s2 = s3 = s4 = s5 = s6 = 0; | ||||
|             for (int l = 0; l < n; l += 2) { | ||||
|                 const int32_t qs = q[l/2+8]; | ||||
|                 s1 += yl[l+8] * (qs & qm[il/2][0]); | ||||
|                 s2 += yl[l+9] * (qs & qm[il/2][1]); | ||||
|                 s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); | ||||
|                 s4 += yl[l+24] * (qs & qm[il/2][2]); | ||||
|                 s5 += yl[l+25] * (qs & qm[il/2][3]); | ||||
|                 s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]); | ||||
|             } | ||||
|             d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); | ||||
|             d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); | ||||
|             sumf1[row] += d1 * (scales[1] - 32); | ||||
|             sumf2[row] += d2 * (scales[3] - 32); | ||||
| 
 | ||||
|             q  += step; | ||||
|             h  += step; | ||||
|  | @ -1201,17 +1233,20 @@ kernel void kernel_mul_mat_q3_K_f32( | |||
| 
 | ||||
|         } | ||||
| 
 | ||||
|         y1 += 2 * QK_K; | ||||
|         y1 += 4 * QK_K; | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     for (int row = 0; row < 2; ++row) { | ||||
|         const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); | ||||
|         const float tot = simd_sum(sumf); | ||||
|         const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); | ||||
|         sumf1[row] = simd_sum(sumf); | ||||
|     } | ||||
|     if (tiisg == 0) { | ||||
|             dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; | ||||
|         for (int row = 0; row < 2; ++row) { | ||||
|             dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row]; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| #else | ||||
| kernel void kernel_mul_mat_q3_K_f32( | ||||
|  | @ -1564,17 +1599,25 @@ kernel void kernel_mul_mat_q5_K_f32( | |||
|             sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); | ||||
|             sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); | ||||
| 
 | ||||
|             float4 acc = {0.f, 0.f, 0.f, 0.f}; | ||||
|             float4 acc1 = {0.f}; | ||||
|             float4 acc2 = {0.f}; | ||||
|             for (int l = 0; l < n; ++l) { | ||||
|                 uint8_t h = qh[l]; | ||||
|                 acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0)); | ||||
|                 acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0)); | ||||
|                 acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0)); | ||||
|                 acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0)); | ||||
|                 acc1[0] += yl[l+0] * (q1[l] & 0x0F); | ||||
|                 acc1[1] += yl[l+8] * (q1[l] & 0xF0); | ||||
|                 acc1[2] += yh[l+0] * (q2[l] & 0x0F); | ||||
|                 acc1[3] += yh[l+8] * (q2[l] & 0xF0); | ||||
|                 acc2[0] += h & hm1 ? yl[l+0] : 0.f; | ||||
|                 acc2[1] += h & hm2 ? yl[l+8] : 0.f; | ||||
|                 acc2[2] += h & hm3 ? yh[l+0] : 0.f; | ||||
|                 acc2[3] += h & hm4 ? yh[l+8] : 0.f; | ||||
|             } | ||||
|             const float dall = dh[0]; | ||||
|             const float dmin = dh[1]; | ||||
|             sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) - | ||||
|             sumf[row] += dall * (sc8[0] * (acc1[0] +  16.f*acc2[0]) + | ||||
|                                  sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + | ||||
|                                  sc8[4] * (acc1[2] +  16.f*acc2[2]) + | ||||
|                                  sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - | ||||
|                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); | ||||
| 
 | ||||
|             q1 += step; | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue