fix naive implementation
This commit is contained in:
parent
b1479dfbc5
commit
0afe47fa5f
1 changed files with 3 additions and 1 deletions
|
@ -207,7 +207,9 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
|
||||||
struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
|
struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
|
||||||
kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0]));
|
kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0]));
|
||||||
kq = ggml_soft_max(ctx0, kq);
|
kq = ggml_soft_max(ctx0, kq);
|
||||||
kq = ggml_mul_mat(ctx0, model.v, kq);
|
kq = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, model.v)), kq);
|
||||||
|
kq = ggml_permute (ctx0, kq, 0, 2, 1, 3);
|
||||||
|
//kq = ggml_cont_2d (ctx0, kq, model.q->ne[0] * model.q->ne[2], model.q->ne[1]);
|
||||||
ggml_build_forward_expand(gf, kq);
|
ggml_build_forward_expand(gf, kq);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue