metal : implement RoPE (mode = 2) + avoid ggml_repeat

This commit is contained in:
Georgi Gerganov 2023-08-23 10:41:35 +03:00
parent e3c52bd990
commit 99bb26078f
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 25 additions and 13 deletions

View file

@ -571,7 +571,25 @@ kernel void kernel_rope(
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
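// for each n_dims-wide block, rotate the pair (x[i0], x[i0 + n_dims/2]) by theta;
// theta is multiplied by theta_scale after every pair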
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);
theta *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
const float x0 = src[0];
const float x1 = src[n_dims/2];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
}
}
}
}

View file

@ -2445,19 +2445,15 @@ static struct ggml_cgraph * llm_build_falcon(
attn_norm = ggml_norm(ctx0, inpL, norm_eps);
attn_norm = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].attn_norm, attn_norm),
attn_norm),
ggml_repeat(ctx0, model.layers[il].attn_norm_b, attn_norm));
ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm),
model.layers[il].attn_norm_b);
if (model.layers[il].attn_norm_2) { // Falcon-40B
cur = ggml_norm(ctx0, inpL, norm_eps);
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].attn_norm_2, cur),
cur),
ggml_repeat(ctx0, model.layers[il].attn_norm_2_b, cur));
ggml_mul(ctx0, cur, model.layers[il].attn_norm_2),
model.layers[il].attn_norm_2_b);
} else { // Falcon 7B
cur = attn_norm;
}
@ -2595,10 +2591,8 @@ static struct ggml_cgraph * llm_build_falcon(
cur = ggml_norm(ctx0, inpL, norm_eps);
cur = ggml_add(ctx0,
ggml_mul(ctx0,
ggml_repeat(ctx0, model.output_norm, cur),
cur),
ggml_repeat(ctx0, model.output_norm_b, cur));
ggml_mul(ctx0, cur, model.output_norm),
model.output_norm_b);
ggml_set_name(cur, "result_norm");
}