ggml : add ggml_soft_max_ext (#4256)
* metal : implement soft_max_ext
* cuda : implement soft_max_ext
* ggml : implement soft_max_ext (CPU)
* batched-bench : print threads

ggml-ci

* metal : simplify soft_max encoding

ggml-ci

* cuda : use 512 threads for soft_max instead of 32
* ggml : update soft max cpu
* cuda : do warp-based block reduce
* cuda : increase max block size to 1024
* cuda : fix warp reduction initialization of shared mem
* metal : warp-based reduction for soft max kernel
* metal : warp-based reduce for rms_norm
* metal : simplify soft max kernel

ggml-ci

* alloc : fix build with debug
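The headline change is the fused operator itself: ggml_soft_max_ext(ctx, a, mask, scale) computes soft_max(a*scale + mask) in a single op, so backends can apply the scale and the additive mask inside the softmax kernel instead of running a separate scale, add, and soft_max. A minimal sketch of how a caller might build it, assuming that signature; the attention-style tensor names and shapes below are illustrative, not taken from this commit:

// Sketch only: fused masked softmax via the new op.
// dst = soft_max(kq*scale + kq_mask); the mask operand may be NULL for a plain soft_max.
#include <math.h>
#include "ggml.h"

static struct ggml_tensor * build_masked_softmax(
        struct ggml_context * ctx,
        struct ggml_tensor  * kq,        // attention scores, e.g. [n_kv, n_tokens, n_head] (hypothetical)
        struct ggml_tensor  * kq_mask,   // additive mask broadcast over rows (hypothetical name)
        int                   n_embd_head) {
    // Can replace a separate ggml_scale + ggml_add(mask) + ggml_soft_max sequence.
    return ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf((float) n_embd_head));
}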
parent 1d144112c0
commit ef47ec18da
8 changed files with 311 additions and 196 deletions

ggml-metal.m (43 changes)
@@ -1028,20 +1028,27 @@ void ggml_metal_graph_compute(
                             int nth = 32; // SIMD width
 
                             if (ne00%4 == 0) {
+                                while (nth < ne00/4 && nth < 256) {
+                                    nth *= 2;
+                                }
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
                             } else {
-                                do {
+                                while (nth < ne00 && nth < 1024) {
                                     nth *= 2;
-                                } while (nth <= ne00 && nth <= 1024);
-                                nth /= 2;
+                                }
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max];
                             }
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+
+                            const float scale = ((float *) dst->op_params)[0];
+
+                            [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst    atIndex:2];
+                            [encoder setBytes:&ne00  length:sizeof(ne00)  atIndex:3];
+                            [encoder setBytes:&ne01  length:sizeof(ne01)  atIndex:4];
+                            [encoder setBytes:&ne02  length:sizeof(ne02)  atIndex:5];
+                            [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -1351,15 +1358,19 @@ void ggml_metal_graph_compute(
                             float eps;
                             memcpy(&eps, dst->op_params, sizeof(float));
 
-                            const int nth = MIN(512, ne00);
+                            int nth = 32; // SIMD width
+
+                            while (nth < ne00/4 && nth < 1024) {
+                                nth *= 2;
+                            }
 
                             [encoder setComputePipelineState:ctx->pipeline_rms_norm];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
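For reference, the extra src1 (mask) and scale arguments passed to the soft_max pipeline in the first hunk implement dst = soft_max(src0*scale + src1) row by row. A minimal single-row CPU sketch of that semantics, for illustration only (not the actual ggml_compute_forward_soft_max code):

// Reference semantics of soft_max_ext for one row of ne00 values (sketch).
// mask may be NULL, mirroring the optional src1 in the kernel launch above.
#include <math.h>
#include <stddef.h>

static void soft_max_ext_row(float * dst, const float * src, const float * mask,
                             float scale, size_t ne00) {
    float max_val = -INFINITY;
    for (size_t i = 0; i < ne00; i++) {
        const float v = src[i]*scale + (mask ? mask[i] : 0.0f);
        if (v > max_val) {
            max_val = v;
        }
    }

    float sum = 0.0f;
    for (size_t i = 0; i < ne00; i++) {
        const float v = src[i]*scale + (mask ? mask[i] : 0.0f);
        dst[i] = expf(v - max_val); // subtract the row max for numerical stability
        sum   += dst[i];
    }

    for (size_t i = 0; i < ne00; i++) {
        dst[i] /= sum;
    }
}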
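Both hunks also move to a fixed 32*sizeof(float) of threadgroup memory. With the warp/SIMD-group based reduction named in the commit message, each 32-wide group reduces its own values first and stages only one partial result in shared memory, and with at most 1024 threads per threadgroup there can be at most 32 groups. A rough host-side C sketch of that two-level pattern, purely illustrative and not the Metal or CUDA kernel from this commit:

// Two-level reduction sketch: per-"SIMD group" partial sums, then a final
// pass over at most 32 partials -- why a 32-float threadgroup buffer is
// enough for any block of up to 1024 threads.
#include <stddef.h>

#define SIMD_WIDTH 32

static float block_reduce_sum(const float * vals, size_t n /* assumed n <= 1024 */) {
    float partial[SIMD_WIDTH] = { 0.0f };     // stand-in for the 32-float threadgroup buffer

    const size_t n_groups = (n + SIMD_WIDTH - 1)/SIMD_WIDTH; // <= 32 when n <= 1024
    for (size_t g = 0; g < n_groups; g++) {
        float s = 0.0f;                        // on the GPU this is a shuffle/simd_sum reduce
        for (size_t i = g*SIMD_WIDTH; i < n && i < (g + 1)*SIMD_WIDTH; i++) {
            s += vals[i];
        }
        partial[g] = s;                        // one slot per SIMD group
    }

    float sum = 0.0f;                          // final reduce over the per-group partials
    for (size_t g = 0; g < n_groups; g++) {
        sum += partial[g];
    }
    return sum;
}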