From 84b4f81ef184b56b078de3c8828fef437db843fd Mon Sep 17 00:00:00 2001
From: zhiyuan li
Date: Fri, 27 Dec 2024 13:38:44 +0800
Subject: [PATCH] metal : initial support for RWKV WKV6 on Apple

---
 ggml/src/ggml-metal/ggml-metal.m     | 48 +++++++++++++++
 ggml/src/ggml-metal/ggml-metal.metal | 86 ++++++++++++++++++++++++++++
 2 files changed, 134 insertions(+)

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index f6c427f7f..a25d42925 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -182,6 +182,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
+    GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
@@ -788,6 +789,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                 norm,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,         ssm_conv_f32,         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,         ssm_scan_f32,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,        rwkv_wkv6_f32,        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,       mul_mv_f32_f32,       has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,      mul_mv_bf16_f32,      has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, has_simdgroup_reduction && use_bfloat);
@@ -1249,6 +1251,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+        case GGML_OP_RWKV_WKV6:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -2152,6 +2155,51 @@ static void ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
+        case GGML_OP_RWKV_WKV6:
+            {
+                const int64_t B = dst->src[5]->ne[1];
+                const int64_t T = dst->src[0]->ne[3];
+                const int64_t C = dst->ne[0];
+                const int64_t H = dst->src[0]->ne[2];
+
+                GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+                GGML_ASSERT(C % H == 0);
+                GGML_ASSERT(C / H == 64); // The current Metal kernel is designed for RWKV6, HEAD_SIZE == 64
+
+                size_t offs_k   = 0;
+                size_t offs_v   = 0;
+                size_t offs_r   = 0;
+                size_t offs_tf  = 0;
+                size_t offs_td  = 0;
+                size_t offs_s   = 0;
+                size_t offs_dst = 0;
+
+                id<MTLBuffer> id_k   = dst->src[0] ? ggml_metal_get_buffer(dst->src[0], &offs_k)  : nil;
+                id<MTLBuffer> id_v   = dst->src[1] ? ggml_metal_get_buffer(dst->src[1], &offs_v)  : nil;
+                id<MTLBuffer> id_r   = dst->src[2] ? ggml_metal_get_buffer(dst->src[2], &offs_r)  : nil;
+                id<MTLBuffer> id_tf  = dst->src[3] ? ggml_metal_get_buffer(dst->src[3], &offs_tf) : nil;
+                id<MTLBuffer> id_td  = dst->src[4] ? ggml_metal_get_buffer(dst->src[4], &offs_td) : nil;
+                id<MTLBuffer> id_s   = dst->src[5] ? ggml_metal_get_buffer(dst->src[5], &offs_s)  : nil;
+                id<MTLBuffer> id_dst = dst         ? ggml_metal_get_buffer(dst,         &offs_dst) : nil;
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_k   offset:offs_k   atIndex:0];
+                [encoder setBuffer:id_v   offset:offs_v   atIndex:1];
+                [encoder setBuffer:id_r   offset:offs_r   atIndex:2];
+                [encoder setBuffer:id_tf  offset:offs_tf  atIndex:3];
+                [encoder setBuffer:id_td  offset:offs_td  atIndex:4];
+                [encoder setBuffer:id_s   offset:offs_s   atIndex:5];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
+
+                [encoder setBytes:&B length:sizeof(B) atIndex:7];
+                [encoder setBytes:&T length:sizeof(T) atIndex:8];
+                [encoder setBytes:&C length:sizeof(C) atIndex:9];
+                [encoder setBytes:&H length:sizeof(H) atIndex:10];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(B * H, 1, 1) threadsPerThreadgroup:MTLSizeMake(C / H, 1, 1)];
+            } break;
         case GGML_OP_MUL_MAT:
             {
                 GGML_ASSERT(ne00 == ne10);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index f394d743c..e75f8aecf 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -1366,6 +1366,92 @@ kernel void kernel_ssm_scan_f32(
     }
 }
 
+kernel void kernel_rwkv_wkv6_f32(
+        device const float * k,
+        device const float * v,
+        device const float * r,
+        device const float * tf,
+        device const float * td,
+        device const float * state_in,
+        device       float * dst,
+        constant uint & B,
+        constant uint & T,
+        constant uint & C,
+        constant uint & H,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+
+    const uint head_size = 64; // rwkv6
+    const uint batch_id  = tgpig.x / H;
+    const uint head_id   = tgpig.x % H;
+    const uint tid       = tpitg.x;
+
+    if (batch_id >= B || head_id >= H) {
+        return;
+    }
+
+    const uint state_size   = C * head_size;
+    const uint n_seq_tokens = T / B;
+
+    threadgroup float _k[head_size];
+    threadgroup float _r[head_size];
+    threadgroup float _tf[head_size];
+    threadgroup float _td[head_size];
+
+    float state[head_size];
+    #pragma unroll(64)
+    for (uint i = 0; i < head_size; i++) {
+        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+            + i * head_size + tid];
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    _tf[tid] = tf[head_id * head_size + tid];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
+    const uint end_t   = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
+
+    for (uint t = start_t; t < end_t; t += C) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        _k[tid]  = k[t];
+        _r[tid]  = r[t];
+        _td[tid] = td[t];
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float v_val = v[t];
+        float y = 0.0;
+
+        #pragma unroll(64)
+        for (uint j = 0; j < head_size; j += 4) {
+            float4 k_vec  = float4(_k[j],  _k[j+1],  _k[j+2],  _k[j+3]);
+            float4 r_vec  = float4(_r[j],  _r[j+1],  _r[j+2],  _r[j+3]);
+            float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
+            float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
+            float4 s_vec  = float4(state[j], state[j+1], state[j+2], state[j+3]);
+
+            float4 kv = k_vec * v_val;
+
+            float4 temp = tf_vec * kv + s_vec;
+            y += dot(r_vec, temp);
+
+            s_vec = s_vec * td_vec + kv;
+            state[j]   = s_vec.x;
+            state[j+1] = s_vec.y;
+            state[j+2] = s_vec.z;
+            state[j+3] = s_vec.w;
+        }
+
+        dst[t] = y;
+    }
+    #pragma unroll(64)
+    for (uint i = 0; i < head_size; i++) {
+        dst[T * C + batch_id * state_size + head_id * head_size * head_size
+            + i * head_size + tid] = state[i];
+    }
+}
+
 kernel void kernel_argmax(
         device const void * x,
         device int32_t * dst,
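
Note for reviewers: below is a minimal scalar C sketch of the per-channel WKV6
recurrence that each thread of kernel_rwkv_wkv6_f32 computes. It is reference
material only, not part of the patch; the function name wkv6_channel_ref and
the flat array layout are hypothetical, with head_size fixed to 64 as the host
code asserts.

    /* Hypothetical scalar reference for one (batch, head, channel) triple,
     * i.e. one Metal thread with tid == i. Pointers are pre-offset to the
     * head's base; stride is the per-token stride (C in the kernel). */
    #define HEAD_SIZE 64

    static void wkv6_channel_ref(
            const float * k,     /* key,        k[t*stride + j]            */
            const float * v,     /* value,      v[t*stride + i]            */
            const float * r,     /* receptance, r[t*stride + j]            */
            const float * tf,    /* time_first, per-head constant, tf[j]   */
            const float * td,    /* time_decay, td[t*stride + j]           */
            float       * state, /* HEAD_SIZE entries: row i of head state */
            float       * y,     /* one output per token                   */
            int n_seq_tokens, int stride, int i) {
        for (int t = 0; t < n_seq_tokens; t++) {
            const float v_val = v[t*stride + i];
            float acc = 0.0f;
            for (int j = 0; j < HEAD_SIZE; j++) {
                const float kv = k[t*stride + j] * v_val;
                /* current token contributes through tf; history through state */
                acc += r[t*stride + j] * (tf[j] * kv + state[j]);
                /* decay the state, then fold in the current token */
                state[j] = state[j] * td[t*stride + j] + kv;
            }
            y[t] = acc;
        }
    }

The Metal kernel runs this loop once per thread, staging the per-token k, r,
and td rows in threadgroup memory so all 64 threads of a head share the loads,
and vectorizing the inner j loop with float4 arithmetic and dot().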