Another very minor speedup on metal
This commit is contained in:
		
							parent
							
								
									2cb47e0e16
								
							
						
					
					
						commit
						e3ff8c20c8
					
				
					 1 changed files with 37 additions and 29 deletions
				
			
		|  | @ -133,19 +133,24 @@ kernel void kernel_soft_max( | ||||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); |         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // broadcast
 |     //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
 | ||||||
|     if (tpitg[0] == 0) { |     //               the loop, and when that is done, buf[0] has the correct (synchronized) value
 | ||||||
|         buf[0] = buf[0]; |     //if (tpitg[0] == 0) {
 | ||||||
|     } |     //    buf[0] = buf[0];
 | ||||||
|  |     //}
 | ||||||
| 
 | 
 | ||||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); |     //threadgroup_barrier(mem_flags::mem_threadgroup);
 | ||||||
| 
 | 
 | ||||||
|     const float max = buf[0]; |     const float max = buf[0]; | ||||||
| 
 | 
 | ||||||
|     // parallel sum
 |     // parallel sum
 | ||||||
|     buf[tpitg[0]] = 0.0f; |     buf[tpitg[0]] = 0.0f; | ||||||
|     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { |     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { | ||||||
|         buf[tpitg[0]] += exp(psrc0[i00] - max); |         const float exp_psrc0 = exp(psrc0[i00] - max); | ||||||
|  |         buf[tpitg[0]] += exp_psrc0; | ||||||
|  |         // Remember the result of exp here. exp is expensive, so we really do not
 | ||||||
|  |         // wish to compute it twice.
 | ||||||
|  |         pdst[i00] = exp_psrc0; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // reduce
 |     // reduce
 | ||||||
|  | @ -157,17 +162,18 @@ kernel void kernel_soft_max( | ||||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); |         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // broadcast
 |     // broadcast - not needed, see above
 | ||||||
|     if (tpitg[0] == 0) { |     //// broadcast
 | ||||||
|         buf[0] = buf[0]; |     //if (tpitg[0] == 0) {
 | ||||||
|     } |     //    buf[0] = buf[0];
 | ||||||
|  |     //}
 | ||||||
| 
 | 
 | ||||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); |     //threadgroup_barrier(mem_flags::mem_threadgroup);
 | ||||||
| 
 | 
 | ||||||
|     const float sum = buf[0]; |     const float sum = buf[0]; | ||||||
| 
 | 
 | ||||||
|     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { |     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { | ||||||
|         pdst[i00] = exp(psrc0[i00] - max) / sum; |         pdst[i00] /= sum; | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -214,25 +220,27 @@ kernel void kernel_norm( | ||||||
|         } |         } | ||||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); |         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|     } |     } | ||||||
|     // broadcast
 |     //// broadcast
 | ||||||
|     if (tpitg == 0) { |     //if (tpitg == 0) {
 | ||||||
|         sum[0] /= ne00; |     //    sum[0] /= ne00;
 | ||||||
|     } |     //}
 | ||||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); |     //threadgroup_barrier(mem_flags::mem_threadgroup);
 | ||||||
|     const float mean  = sum[0]; |     const float mean  = sum[0]; | ||||||
| 
 | 
 | ||||||
|     // recenter
 |     // recenter and VARIANCE
 | ||||||
|     device float * y = dst + tgpig*ne00; |     device float * y = dst + tgpig*ne00; | ||||||
|     for (int i00 = tpitg; i00 < ne00; i00 += ntg) { |  | ||||||
|         y[i00] = x[i00] - mean; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // VARIANCE
 |  | ||||||
|     // parallel sum
 |  | ||||||
|     sum[tpitg] = 0.0f; |     sum[tpitg] = 0.0f; | ||||||
|     for (int i00 = tpitg; i00 < ne00; i00 += ntg) { |     for (int i00 = tpitg; i00 < ne00; i00 += ntg) { | ||||||
|  |         y[i00] = x[i00] - mean; | ||||||
|         sum[tpitg] += y[i00] * y[i00]; |         sum[tpitg] += y[i00] * y[i00]; | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     //// VARIANCE
 | ||||||
|  |     //// parallel sum
 | ||||||
|  |     //sum[tpitg] = 0.0f;
 | ||||||
|  |     //for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
 | ||||||
|  |     //    sum[tpitg] += y[i00] * y[i00];
 | ||||||
|  |     //}
 | ||||||
|     // reduce
 |     // reduce
 | ||||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); |     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|     for (uint i = ntg/2; i > 0; i /= 2) { |     for (uint i = ntg/2; i > 0; i /= 2) { | ||||||
|  | @ -241,11 +249,11 @@ kernel void kernel_norm( | ||||||
|         } |         } | ||||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); |         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||||
|     } |     } | ||||||
|     // broadcast
 |     //// broadcast
 | ||||||
|     if (tpitg == 0) { |     //if (tpitg == 0) {
 | ||||||
|         sum[0] /= ne00; |     //    sum[0] /= ne00;
 | ||||||
|     } |     //}
 | ||||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); |     //threadgroup_barrier(mem_flags::mem_threadgroup);
 | ||||||
|     const float variance = sum[0]; |     const float variance = sum[0]; | ||||||
| 
 | 
 | ||||||
|     const float scale = 1.0f/sqrt(variance + eps); |     const float scale = 1.0f/sqrt(variance + eps); | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue