From 2d473a4a9a62c3c11841a8f2274376e78ee727a6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 16 May 2024 12:03:53 +0300 Subject: [PATCH] metal : support rope freq_factors --- ggml-cuda/rope.cu | 26 +++++---- ggml-metal.m | 131 +++++++++++++++++++++++++++------------------- ggml-metal.metal | 6 ++- ggml.c | 13 +++-- 4 files changed, 108 insertions(+), 68 deletions(-) diff --git a/ggml-cuda/rope.cu b/ggml-cuda/rope.cu index fee56a1f2..8d11d9d14 100644 --- a/ggml-cuda/rope.cu +++ b/ggml-cuda/rope.cu @@ -249,8 +249,11 @@ static void rope_neox_cuda_f32( void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const float * src0_d = (const float *)src0->data; const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); @@ -278,23 +281,28 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - const float* freq_factors = nullptr; + const float * freq_factors = nullptr; const int32_t * pos = nullptr; - if ((mode & 1) == 0) { + + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + if (is_neox) { + // TODO: move these asserts to ggml.c GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->ne[0] == ne2); pos = (const int32_t *) src1_d; - if (dst->src[2] != nullptr) { - GGML_ASSERT(dst->src[2]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->src[2]->ne[0] >= n_dims / 2); - freq_factors = (const float*) dst->src[2]->data; + if (src2 != nullptr) { + // TODO: move these asserts to ggml.c + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float*) src2->data; } + } else { + GGML_ASSERT(src2 == nullptr && "TODO: freq_factors not implemented for mode 1"); } - const bool is_neox = mode & 2; - const bool is_glm = mode & 4; - rope_corr_dims corr_dims; ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v); diff --git a/ggml-metal.m b/ggml-metal.m index b0b16dbf7..7bc75f39b 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -927,22 +927,32 @@ static enum ggml_status ggml_metal_graph_compute( const int64_t ne10 = src1 ? src1->ne[0] : 0; const int64_t ne11 = src1 ? src1->ne[1] : 0; const int64_t ne12 = src1 ? src1->ne[2] : 0; - const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13); + const int64_t ne13 = src1 ? src1->ne[3] : 0; const uint64_t nb10 = src1 ? src1->nb[0] : 0; const uint64_t nb11 = src1 ? src1->nb[1] : 0; const uint64_t nb12 = src1 ? src1->nb[2] : 0; - const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); + const uint64_t nb13 = src1 ? src1->nb[3] : 0; - const int64_t ne0 = dst ? dst->ne[0] : 0; - const int64_t ne1 = dst ? dst->ne[1] : 0; - const int64_t ne2 = dst ? dst->ne[2] : 0; - const int64_t ne3 = dst ? dst->ne[3] : 0; + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; GGML_UNUSED(ne22); + const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); - const uint64_t nb0 = dst ? dst->nb[0] : 0; - const uint64_t nb1 = dst ? dst->nb[1] : 0; - const uint64_t nb2 = dst ? dst->nb[2] : 0; - const uint64_t nb3 = dst ? dst->nb[3] : 0; + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; + + const int64_t ne0 = dst ? dst->ne[0] : 0; + const int64_t ne1 = dst ? dst->ne[1] : 0; + const int64_t ne2 = dst ? dst->ne[2] : 0; + const int64_t ne3 = dst ? dst->ne[3] : 0; + + const uint64_t nb0 = dst ? dst->nb[0] : 0; + const uint64_t nb1 = dst ? dst->nb[1] : 0; + const uint64_t nb2 = dst ? dst->nb[2] : 0; + const uint64_t nb3 = dst ? dst->nb[3] : 0; const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; @@ -1785,16 +1795,6 @@ static enum ggml_status ggml_metal_graph_compute( const int n_as = src0->ne[2]; // src2 = ids - const int64_t ne20 = src2->ne[0]; - const int64_t ne21 = src2->ne[1]; - const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22); - const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23); - - const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20); - const uint64_t nb21 = src2->nb[1]; - const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22); - const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23); - const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); GGML_ASSERT(src2t == GGML_TYPE_I32); @@ -2244,7 +2244,13 @@ static enum ggml_status ggml_metal_graph_compute( // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -2252,6 +2258,25 @@ static enum ggml_status ggml_metal_graph_compute( memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + const bool is_neox = mode & 2; + const bool is_glm = mode & 4; + + GGML_ASSERT(!is_glm && "GLM RoPE not implemented in Metal"); + + if (is_neox) { + // TODO: move these asserts to ggml.c + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(src1->ne[0] == ne2); + + if (id_src2 != nil) { + // TODO: move these asserts to ggml.c + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + } + } else { + GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for mode 1"); + } + id pipeline = nil; switch (src0->type) { @@ -2263,33 +2288,38 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:19]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; - [encoder setBytes:&mode length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + if (id_src2 != nil) { + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + } else { + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + } + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; + [encoder setBytes:&mode length:sizeof( int) atIndex:22]; + [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:24]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -2535,11 +2565,6 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); - const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); - const uint64_t nb21 = src2 ? src2->nb[1] : 0; - const uint64_t nb22 = src2 ? src2->nb[2] : 0; - const uint64_t nb23 = src2 ? src2->nb[3] : 0; - const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); //const int64_t ne31 = src3 ? src3->ne[1] : 0; const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); diff --git a/ggml-metal.metal b/ggml-metal.metal index 386e9195f..44e54ced2 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1640,6 +1640,7 @@ static void rope_yarn_corr_dims( typedef void (rope_t)( device const void * src0, device const int32_t * src1, + device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1675,6 +1676,7 @@ template kernel void kernel_rope( device const void * src0, device const int32_t * src1, + device const float * src2, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1744,8 +1746,10 @@ kernel void kernel_rope( // simplified from `(ib * n_dims + ic) * inv_ndims` const float cur_rot = inv_ndims*ic - ib; + const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f; + + const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor; - const float theta = theta_0 * pow(freq_base, cur_rot); float cos_theta, sin_theta; rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); diff --git a/ggml.c b/ggml.c index 169f98b23..fa4efccee 100644 --- a/ggml.c +++ b/ggml.c @@ -14311,6 +14311,7 @@ static void ggml_compute_forward_rope_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src2 = dst->src[2]; if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { return; @@ -14370,13 +14371,15 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & 2; const bool is_glm = mode & 4; - const float* freq_factors = NULL; + const float * freq_factors = NULL; if (is_neox) { - if (dst->src[2] != NULL) { - GGML_ASSERT(dst->src[2]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->src[2]->ne[0] >= n_dims / 2); - freq_factors = (const float*) dst->src[2]->data; + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; } + } else { + GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1"); } // backward process uses inverse rotation by cos and sin.