Very minor speedup via simd-group synchronization in f16 x f32
This commit is contained in:
		
							parent
							
								
									69fdbb9abc
								
							
						
					
					
						commit
						2cb47e0e16
					
				
					 2 changed files with 8 additions and 37 deletions
				
			
		|  | @ -971,7 +971,7 @@ void ggml_metal_graph_compute( | |||
|                                 else if (src0t == GGML_TYPE_Q6_K) { | ||||
|                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                                 } else { | ||||
|                                     [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; | ||||
|                                     //[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; | ||||
|                                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                                 } | ||||
|                             } | ||||
|  |  | |||
|  | @ -515,11 +515,8 @@ kernel void kernel_mul_mat_f16_f32( | |||
|         constant  uint64_t & nb12, | ||||
|         constant   int64_t & ne0, | ||||
|         constant   int64_t & ne1, | ||||
|         threadgroup float  * sum [[threadgroup(0)]], | ||||
|         uint3 tgpig[[threadgroup_position_in_grid]], | ||||
|         uint3  tpig[[thread_position_in_grid]], | ||||
|         uint3 tpitg[[thread_position_in_threadgroup]], | ||||
|         uint3  tptg[[threads_per_threadgroup]]) { | ||||
|         uint tiisg[[thread_index_in_simdgroup]]) { | ||||
| 
 | ||||
|     const int64_t r0 = tgpig.x; | ||||
|     const int64_t r1 = tgpig.y; | ||||
|  | @ -528,42 +525,16 @@ kernel void kernel_mul_mat_f16_f32( | |||
|     device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); | ||||
|     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); | ||||
| 
 | ||||
|     uint ith = tpitg.x; | ||||
|     uint nth = tptg.x; | ||||
| 
 | ||||
|     sum[ith] = 0.0f; | ||||
| 
 | ||||
|     for (int i = ith; i < ne00; i += nth) { | ||||
|         sum[ith] += (float) x[i] * (float) y[i]; | ||||
|     float sumf = 0; | ||||
|     for (int i = tiisg; i < ne00; i += 32) { | ||||
|         sumf += (float) x[i] * (float) y[i]; | ||||
|     } | ||||
| 
 | ||||
|     // accumulate the sum from all threads in the threadgroup
 | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|     if (ith%4 == 0) { | ||||
|         for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; | ||||
|     } | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|     if (ith%16 == 0) { | ||||
|         for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; | ||||
|     } | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|     if (ith == 0) { | ||||
|         for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; | ||||
|         dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; | ||||
|     float all_sum = simd_sum(sumf); | ||||
|     if (tiisg == 0) { | ||||
|         dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; | ||||
|     } | ||||
| 
 | ||||
|     // Original implementation. Left behind commented out for now
 | ||||
|     //threadgroup_barrier(mem_flags::mem_threadgroup);
 | ||||
|     //for (uint i = tptg.x/2; i > 0; i /= 2) {
 | ||||
|     //    if (tpitg.x < i) {
 | ||||
|     //        sum[tpitg.x] += sum[tpitg.x + i];
 | ||||
|     //    }
 | ||||
|     //    threadgroup_barrier(mem_flags::mem_threadgroup);
 | ||||
|     //}
 | ||||
|     //
 | ||||
|     //if (tpitg.x == 0) {
 | ||||
|     //    dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
 | ||||
|     //}
 | ||||
| } | ||||
| 
 | ||||
| kernel void kernel_alibi_f32( | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue