From ad80e5a4a7a99908dfb38ed025c2a4cba4d3f839 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 4 Sep 2023 19:46:52 +0300
Subject: [PATCH 1/5] llama : add ggml_cont to trigger bug with Metal

---
 llama.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index c97c1462f..097de7221 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2418,11 +2418,11 @@ static struct ggml_cgraph * llm_build_llama(
 
             // split cached V into n_head heads
             struct ggml_tensor * V =
-                ggml_view_3d(ctx0, kv_self.v,
+                ggml_cont(ctx0, ggml_view_3d(ctx0, kv_self.v,
                         n_past + N, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
-                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
+                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il));
             offload_func_v(V);
             ggml_set_name(V, "V");
 

From 7704db252108d3ec69be4fdcaee4d834ea5e8fa8 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 4 Sep 2023 20:48:25 +0300
Subject: [PATCH 2/5] ggml : just in case

---
 ggml.c | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml.c b/ggml.c
index 38b1155c1..fe06c4067 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4285,7 +4285,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 }
 
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
-    size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
+    size_t nbytes = (tensor->ne[0]*tensor->nb[0])/ggml_blck_size(tensor->type);
     for (int i = 1; i < GGML_MAX_DIMS; ++i) {
         nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
     }

From ebd3467cc87d2115f9a6ed293c53879f9174db4e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 4 Sep 2023 20:48:46 +0300
Subject: [PATCH 3/5] metal : more readable kernel

---
 ggml-metal.metal | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/ggml-metal.metal b/ggml-metal.metal
index 119fcbeb6..6aa5b9ebf 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -783,11 +783,11 @@ kernel void kernel_cpy_f16_f16(
     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
 
-    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        dst_data[i00] = src[0];
+        device const half * src = (device half *) ((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i00*nb0);
+
+        *dst_data = *src;
     }
 }
 

From 60c2ef6d92865707c960331528d907dfa50f4b3a Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 4 Sep 2023 20:49:09 +0300
Subject: [PATCH 4/5] metal : utilize view_src to see if a tensor is a view

---
 ggml-metal.m | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index d0d23442e..32ee2795a 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -541,10 +541,7 @@ void ggml_metal_graph_find_concurrency(
             int64_t data_start = (int64_t) gf->nodes[i]->data;
             int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
             for (int j = n_start; j < i; j++) {
-                if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
-                                    && gf->nodes[j]->op != GGML_OP_VIEW \
-                                    && gf->nodes[j]->op != GGML_OP_TRANSPOSE \
-                                    && gf->nodes[j]->op != GGML_OP_PERMUTE) {
+                if (nodes_unused[j] && gf->nodes[j]->view_src == NULL) {
                     if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
                         ((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
                         continue;

From f3a84b2e0d03aca09fc6fd2cd873c9fd162a0f4e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 4 Sep 2023 21:44:48 +0300
Subject: [PATCH 5/5] llama : better express the KV cache dependencies in the graph

---
 ggml.c    |  2 ++
 llama.cpp | 70 +++++++++++++++++++++++++++++++------------------------
 2 files changed, 41 insertions(+), 31 deletions(-)

diff --git a/ggml.c b/ggml.c
index fe06c4067..696fb3d83 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5213,6 +5213,8 @@ struct ggml_tensor * ggml_view_tensor(
         result->nb[i] = src->nb[i];
     }
 
+    result->op = GGML_OP_VIEW;
+
     return result;
 }
 
diff --git a/llama.cpp b/llama.cpp
index 097de7221..92d4ddafb 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2341,45 +2341,53 @@ static struct ggml_cgraph * llm_build_llama(
             // compute Q and K and RoPE them
             struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
             offload_func_kq(tmpk);
-            ggml_set_name(tmpk, "tmpk");
+            ggml_set_name (tmpk, "tmpk");
 
             struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
             offload_func_kq(tmpq);
-            ggml_set_name(tmpq, "tmpq");
+            ggml_set_name (tmpq, "tmpq");
+
+            struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+            offload_func_v(tmpv);
+            ggml_set_name (tmpv, "tmpv");
 
             struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
             offload_func_kq(Kcur);
-            ggml_set_name(Kcur, "Kcur");
+            ggml_set_name (Kcur, "Kcur");
 
             struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
             offload_func_kq(Qcur);
-            ggml_set_name(Qcur, "Qcur");
+            ggml_set_name (Qcur, "Qcur");
+
+            // compute the transposed [N, n_embd] V matrix
+            struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
+            offload_func_v(Vcur);
+            ggml_set_name (Vcur, "Vcur");
+
+            struct ggml_tensor * k;
+            struct ggml_tensor * v;
 
             // store key and value to memory
             {
-                // compute the transposed [N, n_embd] V matrix
+                struct ggml_tensor * k_view = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+                offload_func_kq(k_view);
+                ggml_set_name (k_view, "k_view");
 
-                struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
-                offload_func_v(tmpv);
-                ggml_set_name(tmpv, "tmpv");
-
-                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
-                offload_func_v(Vcur);
-                ggml_set_name(Vcur, "Vcur");
-
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
-                offload_func_kq(k);
-                ggml_set_name(k, "k");
-
-                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+                struct ggml_tensor * v_view = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
                         (   n_ctx)*ggml_element_size(kv_self.v),
                         (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
-                offload_func_v(v);
-                ggml_set_name(v, "v");
+                offload_func_v(v_view);
+                ggml_set_name (v_view, "v_view");
 
                 // important: storing RoPE-ed version of K in the KV cache!
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+                struct ggml_tensor * k_cpy = ggml_cpy(ctx0, Kcur, k_view);
+                struct ggml_tensor * v_cpy = ggml_cpy(ctx0, Vcur, v_view);
+
+                // TODO: replace with ggml_dependency / ggml_depends_on
+                k = ggml_view_tensor(ctx0, kv_self.k);
+                v = ggml_view_tensor(ctx0, kv_self.v);
+                k->src[0] = k_cpy;
+                v->src[0] = v_cpy;
             }
 
             struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
@@ -2387,11 +2395,11 @@ static struct ggml_cgraph * llm_build_llama(
             ggml_set_name(Q, "Q");
 
             struct ggml_tensor * K =
-                ggml_view_3d(ctx0, kv_self.k,
+                ggml_view_3d(ctx0, k,
                         n_embd_head, n_past + N, n_head_kv,
-                        ggml_element_size(kv_self.k)*n_embd_gqa,
-                        ggml_element_size(kv_self.k)*n_embd_head,
-                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+                        ggml_element_size(k)*n_embd_gqa,
+                        ggml_element_size(k)*n_embd_head,
+                        ggml_element_size(k)*n_embd_gqa*n_ctx*il);
             offload_func_kq(K);
             ggml_set_name(K, "K");
 
@@ -2418,11 +2426,11 @@ static struct ggml_cgraph * llm_build_llama(
 
             // split cached V into n_head heads
             struct ggml_tensor * V =
-                ggml_cont(ctx0, ggml_view_3d(ctx0, kv_self.v,
+                ggml_view_3d(ctx0, v,
                         n_past + N, n_embd_head, n_head_kv,
-                        ggml_element_size(kv_self.v)*n_ctx,
-                        ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
-                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il));
+                        ggml_element_size(v)*n_ctx,
+                        ggml_element_size(v)*n_ctx*n_embd_head,
+                        ggml_element_size(v)*n_ctx*n_embd_gqa*il);
             offload_func_v(V);
             ggml_set_name(V, "V");
 
@@ -2434,7 +2442,7 @@ static struct ggml_cgraph * llm_build_llama(
 
             // make V contiguous in memory to speed up the matmul, however we waste time on the copy
             // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
            // is there a better way?
-            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
+            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, v->type, n_past + N, n_embd_head, n_head));
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
 #endif
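
Some context on the series: PATCH 1 wraps the strided 3D view of the cached V in ggml_cont(), which forces a copy of the view into a fresh contiguous buffer; on the Metal backend that copy is what dispatches the kernel_cpy_* kernels and reproduces the bug. PATCH 3 then rewrites kernel_cpy_f16_f16 so that both the source and the destination address are recomputed per element with full stride arithmetic (i00*nb0 rather than indexing dst_data[i00] off a fixed base), which stays correct when the destination is not contiguous.

PATCH 2 concerns the first term of ggml_nbytes(), which computes ne[0]*nb[0]/blck_size + sum_{i>=1} (ne[i]-1)*nb[i]. In C, * and / have equal precedence and associate left-to-right, so the added parentheses do not change the result; they only pin down the intended multiply-before-divide order, which matters for quantized types where nb[0] is the byte size of a whole block. As an illustration (numbers ours, not from the patch): a Q4_0 row with ne[0] = 4096 and 18-byte blocks of 32 elements gives (4096*18)/32 = 2304 bytes, while dividing first would truncate 18/32 to 0.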
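
PATCH 4 leans on a ggml invariant: every tensor that aliases another tensor's memory (the results of ggml_reshape_*, ggml_view_*, ggml_transpose and ggml_permute) carries a non-NULL view_src, so one pointer test subsumes the four op-type comparisons in the concurrency pass. A minimal sketch of the test, assuming only the ggml_tensor fields these patches use (the helper name is illustrative, not ggml API):

    #include "ggml.h"

    // A node owns its buffer only when it is not a view of another tensor;
    // only such nodes need to be checked for overlapping memory regions.
    static bool node_owns_data(const struct ggml_tensor * t) {
        return t->view_src == NULL;
    }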
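
The heart of PATCH 5 is a workaround for ggml graphs having no explicit dependency edges: the attention reads of kv_self.k and kv_self.v must happen after the ggml_cpy() writes into them, so the patch routes the reads through a full-buffer view whose src[0] is wired to the copy node; the in-patch TODO about ggml_dependency / ggml_depends_on marks this as a stand-in for a real dependency primitive. A minimal sketch of the trick, using the ggml API as it appears in these patches (the depend_on helper is hypothetical):

    #include "ggml.h"

    // Return a view of buf that the graph will only evaluate after write_op:
    // ggml_view_tensor() creates a zero-offset view of the whole buffer, and
    // setting src[0] makes write_op an input of that view, so
    // ggml_build_forward_expand() orders the write before any consumer of
    // the returned tensor.
    static struct ggml_tensor * depend_on(
            struct ggml_context * ctx,
            struct ggml_tensor  * buf,        // e.g. kv_self.k
            struct ggml_tensor  * write_op) { // e.g. k_cpy
        struct ggml_tensor * view = ggml_view_tensor(ctx, buf);
        view->src[0] = write_op; // artificial dependency edge
        return view;
    }

With such a helper the store block would reduce to k = depend_on(ctx0, kv_self.k, k_cpy) and v = depend_on(ctx0, kv_self.v, v_cpy). The ggml.c hunk complements this by marking the result of ggml_view_tensor as a GGML_OP_VIEW node, so a view that now has a src[0] participates in the graph as an operation rather than a leaf.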