Fix metal wkv6 inference
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
		
							parent
							
								
									65307d279f
								
							
						
					
					
						commit
						d564c4b534
					
				
					 1 changed files with 15 additions and 29 deletions
				
			
		|  | @ -2158,42 +2158,31 @@ static void ggml_metal_encode_node( | |||
|         case GGML_OP_RWKV_WKV6: | ||||
|             { | ||||
|                 const int64_t B = dst->src[5]->ne[1]; | ||||
|                 const int64_t T = dst->src[0]->ne[3]; | ||||
|                 const int64_t T = dst->src[0]->ne[2]; | ||||
|                 const int64_t C = dst->ne[0]; | ||||
|                 const int64_t H = dst->src[0]->ne[2]; | ||||
|                 const int64_t H = dst->src[0]->ne[1]; | ||||
| 
 | ||||
|                 GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); | ||||
|                 GGML_ASSERT(C % H == 0); | ||||
|                 GGML_ASSERT(C / H == 64); // The current Metal kernel is designed for RWKV6, HEAD_SIZE == 64 | ||||
|                 GGML_ASSERT(C / H == 64); | ||||
| 
 | ||||
|                 size_t offs_k = 0; | ||||
|                 size_t offs_v = 0; | ||||
|                 size_t offs_r = 0; | ||||
|                 size_t offs_tf = 0; | ||||
|                 size_t offs_td = 0; | ||||
|                 size_t offs_s = 0; | ||||
|                 size_t offs_dst  = 0; | ||||
|                 size_t offs_src3 = 0; | ||||
|                 size_t offs_src4 = 0; | ||||
|                 size_t offs_src5 = 0; | ||||
| 
 | ||||
|                 id<MTLBuffer> id_k = dst->src[0] ? ggml_metal_get_buffer(dst->src[0], &offs_k) : nil; | ||||
|                 id<MTLBuffer> id_v = dst->src[1] ? ggml_metal_get_buffer(dst->src[1], &offs_v) : nil; | ||||
|                 id<MTLBuffer> id_r = dst->src[2] ? ggml_metal_get_buffer(dst->src[2], &offs_r) : nil; | ||||
|                 id<MTLBuffer> id_tf = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_tf) : nil; | ||||
|                 id<MTLBuffer> id_td = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_td) : nil; | ||||
|                 id<MTLBuffer> id_s = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_s) : nil; | ||||
|                 id<MTLBuffer> id_dst  = dst         ? ggml_metal_get_buffer(dst,         &offs_dst)  : nil; | ||||
|                 id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil; | ||||
|                 id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil; | ||||
|                 id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil; | ||||
| 
 | ||||
|                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline; | ||||
| 
 | ||||
|                 id<MTLCommandBuffer> command_buffer = ctx->queue.commandBuffer; | ||||
|                 id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder]; | ||||
| 
 | ||||
|                 [encoder setComputePipelineState:pipeline]; | ||||
|                 [encoder setBuffer:id_k offset:offs_k atIndex:0]; | ||||
|                 [encoder setBuffer:id_v offset:offs_v atIndex:1]; | ||||
|                 [encoder setBuffer:id_r offset:offs_r atIndex:2]; | ||||
|                 [encoder setBuffer:id_tf offset:offs_tf atIndex:3]; | ||||
|                 [encoder setBuffer:id_td offset:offs_td atIndex:4]; | ||||
|                 [encoder setBuffer:id_s offset:offs_s atIndex:5]; | ||||
|                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; | ||||
|                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; | ||||
|                 [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; | ||||
|                 [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; | ||||
|                 [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; | ||||
|                 [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; | ||||
|                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6]; | ||||
| 
 | ||||
|                 [encoder setBytes:&B length:sizeof(B) atIndex:7]; | ||||
|  | @ -2202,9 +2191,6 @@ static void ggml_metal_encode_node( | |||
|                 [encoder setBytes:&H length:sizeof(H) atIndex:10]; | ||||
| 
 | ||||
|                 [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)]; | ||||
| 
 | ||||
|                 [encoder endEncoding]; | ||||
|                 [command_buffer commit]; | ||||
|             } break; | ||||
|         case GGML_OP_MUL_MAT: | ||||
|             { | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue