diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index a25d42925..246bfdecb 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2158,42 +2158,31 @@ static void ggml_metal_encode_node( case GGML_OP_RWKV_WKV6: { const int64_t B = dst->src[5]->ne[1]; - const int64_t T = dst->src[0]->ne[3]; + const int64_t T = dst->src[0]->ne[2]; const int64_t C = dst->ne[0]; - const int64_t H = dst->src[0]->ne[2]; + const int64_t H = dst->src[0]->ne[1]; GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); GGML_ASSERT(C % H == 0); - GGML_ASSERT(C / H == 64); // The current Metal kernel is designed for RWKV6, HEAD_SIZE == 64 + GGML_ASSERT(C / H == 64); - size_t offs_k = 0; - size_t offs_v = 0; - size_t offs_r = 0; - size_t offs_tf = 0; - size_t offs_td = 0; - size_t offs_s = 0; - size_t offs_dst = 0; + size_t offs_src3 = 0; + size_t offs_src4 = 0; + size_t offs_src5 = 0; - id id_k = dst->src[0] ? ggml_metal_get_buffer(dst->src[0], &offs_k) : nil; - id id_v = dst->src[1] ? ggml_metal_get_buffer(dst->src[1], &offs_v) : nil; - id id_r = dst->src[2] ? ggml_metal_get_buffer(dst->src[2], &offs_r) : nil; - id id_tf = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_tf) : nil; - id id_td = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_td) : nil; - id id_s = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_s) : nil; - id id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; + id id_src3 = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_src3) : nil; + id id_src4 = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_src4) : nil; + id id_src5 = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_src5) : nil; id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline; - id command_buffer = ctx->queue.commandBuffer; - id encoder = [command_buffer computeCommandEncoder]; - [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_k offset:offs_k atIndex:0]; - [encoder setBuffer:id_v offset:offs_v atIndex:1]; - [encoder setBuffer:id_r offset:offs_r atIndex:2]; - [encoder setBuffer:id_tf offset:offs_tf atIndex:3]; - [encoder setBuffer:id_td offset:offs_td atIndex:4]; - [encoder setBuffer:id_s offset:offs_s atIndex:5]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; [encoder setBytes:&B length:sizeof(B) atIndex:7]; @@ -2202,9 +2191,6 @@ static void ggml_metal_encode_node( [encoder setBytes:&H length:sizeof(H) atIndex:10]; [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C/ H, 1, 1)]; - - [encoder endEncoding]; - [command_buffer commit]; } break; case GGML_OP_MUL_MAT: {