From 405c8e90a082fa69a74ab1e7a40c49fd26e99145 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 7 Sep 2023 11:30:14 +0200 Subject: [PATCH] PR suggestion --- ggml-metal.metal | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 7e0f4d3f9..3ffe5aa0f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -700,6 +700,7 @@ kernel void kernel_rope( constant float & freq_base, constant float & freq_scale, uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { const int64_t i3 = tgpig[2]; const int64_t i2 = tgpig[1]; @@ -713,7 +714,7 @@ kernel void kernel_rope( const float inv_ndims = -1.f/n_dims; if (!is_neox) { - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 64) { + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { const float theta = theta_0 * pow(freq_base, inv_ndims*i0); const float cos_theta = cos(theta); @@ -730,7 +731,7 @@ kernel void kernel_rope( } } else { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 2*tiitg; ic < n_dims; ic += 64) { + for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib); const float cos_theta = cos(theta);