mtl : add reshape and transpose handling

This commit is contained in:
Georgi Gerganov 2023-05-31 22:38:40 +03:00
parent 1213af76ce
commit 7ca81e9e65
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 11 additions and 7 deletions

View file

@ -258,6 +258,7 @@ int llama_mtl_eval(
switch (gf->nodes[i]->op) {
case GGML_OP_RESHAPE:
case GGML_OP_TRANSPOSE:
{
// noop
} break;

8
ggml.c
View file

@ -15011,6 +15011,7 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
// create the tensor
// "view" operations are handled differently
// TODO: handle inplac ops - currentl a copy is always made
struct ggml_tensor * tensor = NULL;
@ -15018,8 +15019,11 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
// TODO: implement other view ops
case GGML_OP_RESHAPE:
{
// TODO: implement other dims
tensor = ggml_reshape_3d(*ctx_eval, args[0], ne[0], ne[1], ne[2]);
tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]);
} break;
case GGML_OP_TRANSPOSE:
{
tensor = ggml_transpose(*ctx_eval, args[0]);
} break;
default:
{

View file

@ -1279,15 +1279,14 @@ static bool llama_eval_internal(
ggml_set_name(Qcur, "Qcur");
ggml_set_name(Kcur, "Kcur");
// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(Qcur, "mtl-check");
}
// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
// TODO: TMP !!!!
if (il == 0) {
ggml_set_name(Vcur, "mtl-check");
}
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,