metal : implement RoPE (mode = 2) + avoid ggml_repeat
This commit is contained in:
parent
e3c52bd990
commit
99bb26078f
2 changed files with 25 additions and 13 deletions
|
@ -571,7 +571,25 @@ kernel void kernel_rope(
|
|||
dst_data[1] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
} else {
|
||||
// TODO: implement
|
||||
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
||||
for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
||||
const float cos_theta = cos(theta);
|
||||
const float sin_theta = sin(theta);
|
||||
|
||||
theta *= theta_scale;
|
||||
|
||||
const int64_t i0 = ib*n_dims + ic/2;
|
||||
|
||||
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
||||
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
18
llama.cpp
18
llama.cpp
|
@ -2445,19 +2445,15 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||
attn_norm = ggml_norm(ctx0, inpL, norm_eps);
|
||||
|
||||
attn_norm = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].attn_norm, attn_norm),
|
||||
attn_norm),
|
||||
ggml_repeat(ctx0, model.layers[il].attn_norm_b, attn_norm));
|
||||
ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm),
|
||||
model.layers[il].attn_norm_b);
|
||||
|
||||
if (model.layers[il].attn_norm_2) { // Falcon-40B
|
||||
cur = ggml_norm(ctx0, inpL, norm_eps);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].attn_norm_2, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, model.layers[il].attn_norm_2_b, cur));
|
||||
ggml_mul(ctx0, cur, model.layers[il].attn_norm_2),
|
||||
model.layers[il].attn_norm_2_b);
|
||||
} else { // Falcon 7B
|
||||
cur = attn_norm;
|
||||
}
|
||||
|
@ -2595,10 +2591,8 @@ static struct ggml_cgraph * llm_build_falcon(
|
|||
cur = ggml_norm(ctx0, inpL, norm_eps);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.output_norm, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, model.output_norm_b, cur));
|
||||
ggml_mul(ctx0, cur, model.output_norm),
|
||||
model.output_norm_b);
|
||||
ggml_set_name(cur, "result_norm");
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue