diff --git a/ggml-metal.metal b/ggml-metal.metal index 7e0f4d3f9..3ffe5aa0f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -700,6 +700,7 @@ kernel void kernel_rope( constant float & freq_base, constant float & freq_scale, uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { const int64_t i3 = tgpig[2]; const int64_t i2 = tgpig[1]; @@ -713,7 +714,7 @@ kernel void kernel_rope( const float inv_ndims = -1.f/n_dims; if (!is_neox) { - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 64) { + for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { const float theta = theta_0 * pow(freq_base, inv_ndims*i0); const float cos_theta = cos(theta); @@ -730,7 +731,7 @@ kernel void kernel_rope( } } else { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { - for (int64_t ic = 2*tiitg; ic < n_dims; ic += 64) { + for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib); const float cos_theta = cos(theta);