From 3035c2db91377a8e2bf77b8e85e58badc3308e79 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 May 2024 13:26:12 +0300 Subject: [PATCH] metal : better rope implementation ggml-ci --- ggml-metal.m | 49 ++++++---- ggml-metal.metal | 192 +++++++++++++++++++++---------------- ggml.c | 24 ++--- tests/test-backend-ops.cpp | 16 ++-- 4 files changed, 152 insertions(+), 129 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 73a084681..6a38f046f 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -172,8 +172,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_ROPE_F32, - GGML_METAL_KERNEL_TYPE_ROPE_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_UPSCALE_F32, @@ -626,8 +628,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); @@ -2303,17 +2307,21 @@ static enum ggml_status ggml_metal_graph_compute( const bool is_neox = mode & 2; - if (!is_neox) { - GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox"); - } - id pipeline = nil; - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break; - default: GGML_ASSERT(false); - }; + if (!is_neox) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: GGML_ASSERT(false); + }; + } else { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ASSERT(false); + }; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -2342,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&mode length:sizeof( int) atIndex:22]; - [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:24]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29]; + [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 0cb85e1a5..a86a7cff2 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static void rope_yarn( float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, - thread float * cos_theta, thread float * sin_theta -) { + thread float * cos_theta, thread float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; @@ -1684,43 +1683,8 @@ static void rope_yarn_corr_dims( dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); } -typedef void (rope_t)( - device const void * src0, - device const int32_t * src1, - device const float * src2, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant int & n_orig_ctx, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]); - template -kernel void kernel_rope( +kernel void kernel_rope_norm( device const void * src0, device const int32_t * src1, device const float * src2, @@ -1743,7 +1707,6 @@ kernel void kernel_rope( constant uint64_t & nb3, constant int & n_past, constant int & n_dims, - constant int & mode, constant int & n_orig_ctx, constant float & freq_base, constant float & freq_scale, @@ -1758,69 +1721,130 @@ kernel void kernel_rope( const int64_t i2 = tgpig[1]; const int64_t i1 = tgpig[0]; - const bool is_neox = mode & 2; + float corr_dims[2]; + rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); + + device const int32_t * pos = src1; + + const float theta_base = (float) pos[i2]; + const float inv_ndims = -1.f/n_dims; + + float cos_theta; + float sin_theta; + + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + if (i0 < n_dims) { + const int64_t ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + const float x0 = src[0]; + const float x1 = src[1]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[1] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_neox( + device const void * src0, + device const int32_t * src1, + device const float * src2, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & n_orig_ctx, + constant float & freq_base, + constant float & freq_scale, + constant float & ext_factor, + constant float & attn_factor, + constant float & beta_fast, + constant float & beta_slow, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int64_t i3 = tgpig[2]; + const int64_t i2 = tgpig[1]; + const int64_t i1 = tgpig[0]; float corr_dims[2]; rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); device const int32_t * pos = src1; - const int64_t p = pos[i2]; - - const float theta_base = (float)p; + const float theta_base = (float) pos[i2]; const float inv_ndims = -1.f/n_dims; - if (!is_neox) { - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + float cos_theta; + float sin_theta; + + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { + if (i0 < n_dims) { + const int64_t ic = i0/2; + const float theta = theta_base * pow(freq_base, inv_ndims*i0); - float cos_theta, sin_theta; - rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - const T x0 = src[0]; - const T x1 = src[1]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[1] = x0*sin_theta + x1*cos_theta; - } - } else { - for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) { - if (ic < n_dims) { - const int64_t i0 = ic/2; - - const float freq_factor = src2 != src0 ? src2[i0] : 1.0f; - - const float theta = theta_base * pow(freq_base, inv_ndims*ic); - - float cos_theta, sin_theta; - rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta); - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; - } else { - const int64_t i0 = ic; - - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } + dst_data[0] = src[0]; + dst_data[1] = src[1]; } } } -template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; -template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; +typedef decltype(kernel_rope_norm) kernel_rope_norm_t; +typedef decltype(kernel_rope_neox) kernel_rope_neox_t; + +template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; +template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; + +template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; typedef void (im2col_t)( device const float * x, diff --git a/ggml.c b/ggml.c index 5b204e1a5..328484d77 100644 --- a/ggml.c +++ b/ggml.c @@ -14341,14 +14341,10 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & 2; const float * freq_factors = NULL; - if (is_neox) { - if (src2 != NULL) { - GGML_ASSERT(src2->type == GGML_TYPE_F32); - GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float *) src2->data; - } - } else { - GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox"); + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; } // backward process uses inverse rotation by cos and sin. @@ -14474,14 +14470,10 @@ static void ggml_compute_forward_rope_f16( const bool is_neox = mode & 2; const float * freq_factors = NULL; - if (is_neox) { - if (src2 != NULL) { - GGML_ASSERT(src2->type == GGML_TYPE_F32); - GGML_ASSERT(src2->ne[0] >= n_dims / 2); - freq_factors = (const float *) src2->data; - } - } else { - GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox"); + if (src2 != NULL) { + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(src2->ne[0] >= n_dims / 2); + freq_factors = (const float *) src2->data; } // backward process uses inverse rotation by cos and sin. diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 99b57d0fb..58a3bd6a4 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2236,15 +2236,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (float ef : { 0.0f, 0.7465f }) { for (float af : { 1.0f, 1.4245f }) { for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { - // TODO: ff not supported yet for !neox - test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 7B - if (all) { - test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 13B - test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 30B - test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 65B - } - for (bool ff : {false, true}) { // freq_factors + test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B + + if (all) { + test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B + test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B + test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B + } + if (all) { test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)