PR suggestion
This commit is contained in:
parent
a68e1a5656
commit
405c8e90a0
1 changed files with 3 additions and 2 deletions
|
@ -700,6 +700,7 @@ kernel void kernel_rope(
|
||||||
constant float & freq_base,
|
constant float & freq_base,
|
||||||
constant float & freq_scale,
|
constant float & freq_scale,
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
|
uint3 tptg[[threads_per_threadgroup]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||||
const int64_t i3 = tgpig[2];
|
const int64_t i3 = tgpig[2];
|
||||||
const int64_t i2 = tgpig[1];
|
const int64_t i2 = tgpig[1];
|
||||||
|
@ -713,7 +714,7 @@ kernel void kernel_rope(
|
||||||
const float inv_ndims = -1.f/n_dims;
|
const float inv_ndims = -1.f/n_dims;
|
||||||
|
|
||||||
if (!is_neox) {
|
if (!is_neox) {
|
||||||
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 64) {
|
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
|
||||||
|
|
||||||
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
|
||||||
const float cos_theta = cos(theta);
|
const float cos_theta = cos(theta);
|
||||||
|
@ -730,7 +731,7 @@ kernel void kernel_rope(
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
||||||
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 64) {
|
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
|
||||||
|
|
||||||
const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
|
const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
|
||||||
const float cos_theta = cos(theta);
|
const float cos_theta = cos(theta);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue