Fix metal wkv6 inference
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
65307d279f
commit
d564c4b534
1 changed files with 15 additions and 29 deletions
|
@ -2158,42 +2158,31 @@ static void ggml_metal_encode_node(
|
|||
case GGML_OP_RWKV_WKV6:
|
||||
{
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[3];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[2];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == 64); // The current Metal kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||
GGML_ASSERT(C / H == 64);
|
||||
|
||||
size_t offs_k = 0;
|
||||
size_t offs_v = 0;
|
||||
size_t offs_r = 0;
|
||||
size_t offs_tf = 0;
|
||||
size_t offs_td = 0;
|
||||
size_t offs_s = 0;
|
||||
size_t offs_dst = 0;
|
||||
size_t offs_src3 = 0;
|
||||
size_t offs_src4 = 0;
|
||||
size_t offs_src5 = 0;
|
||||
|
||||
id<MTLBuffer> id_k = dst->src[0] ? ggml_metal_get_buffer(dst->src[0], &offs_k) : nil;
|
||||
id<MTLBuffer> id_v = dst->src[1] ? ggml_metal_get_buffer(dst->src[1], &offs_v) : nil;
|
||||
id<MTLBuffer> id_r = dst->src[2] ? ggml_metal_get_buffer(dst->src[2], &offs_r) : nil;
|
||||
id<MTLBuffer> id_tf = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_tf) : nil;
|
||||
id<MTLBuffer> id_td = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_td) : nil;
|
||||
id<MTLBuffer> id_s = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_s) : nil;
|
||||
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
||||
id<MTLBuffer> id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil;
|
||||
id<MTLBuffer> id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil;
|
||||
id<MTLBuffer> id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
|
||||
|
||||
id<MTLCommandBuffer> command_buffer = ctx->queue.commandBuffer;
|
||||
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_k offset:offs_k atIndex:0];
|
||||
[encoder setBuffer:id_v offset:offs_v atIndex:1];
|
||||
[encoder setBuffer:id_r offset:offs_r atIndex:2];
|
||||
[encoder setBuffer:id_tf offset:offs_tf atIndex:3];
|
||||
[encoder setBuffer:id_td offset:offs_td atIndex:4];
|
||||
[encoder setBuffer:id_s offset:offs_s atIndex:5];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
||||
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
|
||||
[encoder setBytes:&B length:sizeof(B) atIndex:7];
|
||||
|
@ -2202,9 +2191,6 @@ static void ggml_metal_encode_node(
|
|||
[encoder setBytes:&H length:sizeof(H) atIndex:10];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)];
|
||||
|
||||
[encoder endEncoding];
|
||||
[command_buffer commit];
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue