From d1b40efcfa2725e48db9182a85a1eca21f33d638 Mon Sep 17 00:00:00 2001
From: Phillip Kravtsov
Date: Tue, 26 Sep 2023 11:36:36 -0700
Subject: [PATCH] Correct outputs through masked & softmax'd KQ

---
 ggml.c    | 29 +++++++++----------
 llama.cpp | 84 +++++++++++++++++++++++++++----------------------------
 2 files changed, 54 insertions(+), 59 deletions(-)

diff --git a/ggml.c b/ggml.c
index 3cf682ab9..8eaad0c16 100644
--- a/ggml.c
+++ b/ggml.c
@@ -11308,21 +11308,18 @@ static void ggml_compute_forward_mul_mat(
         struct ggml_tensor * dst) {
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
-    if (strncmp(src1->name, "KQ_soft_max", 11) == 0 && params->ith == 0
-        && src1->ne[0] == src1->ne[1]) {
-        GGML_PRINT("\n KQ_softmax at mul mat time for %s\n", src1->name);
+    // thread 0 dumps src1 whenever its name carries the "printme" prefix
+    if (strncmp(src1->name, "printme", 7) == 0
+        && params->ith == 0) {
+        GGML_PRINT("\nInputs to matmul: %s\n", src1->name);
         ggml_print_tensor(src1);
-        if (ggml_nelements(src1) >= 14) {
-            for (int i=0; i < src1->ne[0] * src1->ne[1]; ++i) {
-                if (i % src1->ne[1] == 0) {
-                    GGML_PRINT("\n");
-                }
-                GGML_PRINT(" %f ", ((float *)src1->data)[i]);
+        for (int i=0; i < src1->ne[0] * src1->ne[1]; ++i) {
+            if (i % src1->ne[0] == 0) {
+                GGML_PRINT("\n");
             }
-            GGML_PRINT("\n");
-        } else {
-            GGML_PRINT("Not enough elements to print\n");
+            GGML_PRINT(" %f ", ((float *)src1->data)[i + (src1->ne[0] * src1->ne[1])]); // second ne[0] x ne[1] slice; assumes ne[2] > 1
         }
+        GGML_PRINT("\n");
     }
     GGML_TENSOR_BINARY_OP_LOCALS;
@@ -12726,10 +12723,10 @@ static void ggml_compute_forward_rope_f32(
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
-    if (strncmp(src0->name, "qrot", 4) == 0 && params->ith == 0) {
-        GGML_PRINT("\nValues at RoPE time for %s\n", src0->name);
+    if (strncmp(src0->name, "krot", 4) == 0 && params->ith == 0) {
+        GGML_PRINT("\nInputs to RoPE for %s\n", src0->name);
         ggml_print_tensor(src0);
-        int starts[] = {0, 1, 0, 0};
+        int starts[] = {0, 0, 1, 0};
         ggml_print_tensor_values(src0, starts, 0, 10);
     }
@@ -12860,7 +12857,7 @@ static void ggml_compute_forward_rope_f32(
             }
         }
     }
-    if (strncmp(src0->name, "qrot", 4) == 0 && params->ith == 0) {
+    if (strncmp(src0->name, "krot", 4) == 0 && params->ith == 0) {
         GGML_PRINT("\n dest at RoPE time for %s\n", src0->name);
         // print shape and strides
         int starts[4] = {0,0,1,0};
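
The ggml.c hook above generalizes the old hard-coded KQ_soft_max dump: any tensor whose
name starts with "printme" now has its contents printed by thread 0 when it reaches
ggml_compute_forward_mul_mat as src1. A graph-building site opts in by tagging the tensor,
as the llama.cpp hunks below do for the softmax'd attention scores. A minimal sketch of the
opt-in (the layer-index suffix only keeps names unique; the strncmp() in ggml.c matches the
prefix alone):

    // tag a tensor so the matmul debug hook dumps it on the next forward pass
    struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
    ggml_set_name(KQ_soft_max, format("printme_KQ_soft_max_%d", il).c_str());
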
diff --git a/llama.cpp b/llama.cpp
index 31f92cad2..7a00fe039 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3791,17 +3791,6 @@ static struct ggml_cgraph * llm_build_adept(
         }
         LLAMA_LOG_INFO("\n", __func__);
         inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
-        /*
-        LLAMA_LOG_INFO("\ninpL:\n", __func__);
-        if (ggml_nelements(model.tok_embeddings) >= 5) {
-            for (int i=0; i < 5; ++i) {
-                LLAMA_LOG_INFO(" %f ", ggml_get_f32_1d(model.tok_embeddings, i));
-            }
-            LLAMA_LOG_INFO("\n");
-        } else {
-            LLAMA_LOG_INFO("Not enough elements to print\n", __func__);
-        }
-        */
     } else {
         inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
         ggml_allocr_alloc(lctx.alloc, inpL);
@@ -3812,7 +3801,7 @@ static struct ggml_cgraph * llm_build_adept(
     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_scale);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
-        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); // == n_embd/n_head; matches the tensor's name
     }
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
     //LLAMA_LOG_INFO("Entering n_layers loop\n", __func__);
@@ -3891,18 +3880,19 @@ static struct ggml_cgraph * llm_build_adept(
                 /* offset = */ sizeof(float) * n_embd_head * n_head * N * 2
                 )
             );
+            // Q/K layernorm
            ggml_set_name(tmpq, format("tmpq_%d", il).c_str());
            tmpq = ggml_norm(ctx0, tmpq, hparams.f_norm_eps);
            tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm);
-           ggml_set_name(tmpq, format("preadd_%d", il).c_str());
            tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b);
+           ggml_set_name(tmpq, format("tmpq_%d", il).c_str());
+           log_tensor(tmpq);
            tmpk = ggml_norm(ctx0, tmpk, hparams.f_norm_eps);
            tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm);
+           ggml_set_name(tmpk, format("preadd_%d", il).c_str());
            tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b);
-           ggml_set_name(tmpq, format("tmpq_%d", il).c_str());
            ggml_set_name(tmpk, format("tmpk_%d", il).c_str());
-           log_tensor(tmpq);
            log_tensor(tmpk);
@@ -3913,12 +3903,13 @@ static struct ggml_cgraph * llm_build_adept(
                    /* nb2 = */ wsize * n_embd_head * n_head,
                    /* offset = */ 0
                ));
-           struct ggml_tensor * qpass = ggml_cont(ctx0, ggml_permute(ctx0, ggml_view_3d(
+           // get the second half of tmpq, i.e. tmpq[n_rot:, :, :]
+           struct ggml_tensor * qpass = ggml_cont(ctx0, ggml_view_3d(
                    ctx0, tmpq, n_rot, n_head, N,
-                   /* nb1 = */ wsize * n_rot,
-                   /* nb2 = */ wsize * n_rot * n_head,
-                   /* offset = */ (wsize * n_embd_head * n_head) / 2
-               ), 2, 1, 0, 3));
+                   /* nb1 = */ wsize * n_embd_head,
+                   /* nb2 = */ wsize * n_embd_head * n_head,
+                   /* offset = */ wsize * n_rot
+               ));
            ggml_set_name(qrot, format("qrot_%d", il).c_str());
            ggml_set_name(qpass, format("qpass_%d", il).c_str());
            log_tensor(qrot);
            log_tensor(qpass);
            struct ggml_tensor * krot = ggml_cont(ctx0, ggml_view_3d(
                    ctx0, tmpk, n_rot, n_head, N,
-                   /* nb1 = */ wsize * n_rot,
-                   /* nb2 = */ wsize * n_rot * n_head,
+                   /* nb1 = */ wsize * n_embd_head,
+                   /* nb2 = */ wsize * n_embd_head * n_head,
                    /* offset = */ 0
                ));
-           struct ggml_tensor * kpass = ggml_cont(ctx0,
-               ggml_permute(ctx0,
-                   ggml_view_3d(
+           struct ggml_tensor * kpass = ggml_cont(ctx0, ggml_view_3d(
                    ctx0, tmpk, n_rot, n_head, N,
-                   /* nb1 = */ wsize * n_rot,
-                   /* nb2 = */ wsize * n_rot * n_head,
-                   /* offset = */ (wsize * n_embd_head * n_head) / 2
-               ), 2, 1, 0, 3));
+                   /* nb1 = */ wsize * n_embd_head,
+                   /* nb2 = */ wsize * n_embd_head * n_head,
+                   /* offset = */ wsize * n_rot
+               ));
            ggml_set_name(krot, format("krot_%d", il).c_str());
            ggml_set_name(kpass, format("kpass_%d", il).c_str());
            log_tensor(krot);
@@ -3949,68 +3938,77 @@ static struct ggml_cgraph * llm_build_adept(
                ), 2, 1, 0, 3
            ));
+           ggml_set_name(qrotated, format("qrotated_%d", il).c_str());
+           log_tensor(qrotated);
+           qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
            struct ggml_tensor * krotated = ggml_cont(ctx0, ggml_permute(ctx0,
                ggml_rope_custom_inplace(
                    ctx0, krot, n_past, n_rot, 2, 0, freq_base, freq_scale
                ), 2, 1, 0, 3
            ));
-           ggml_set_name(qrotated, format("qrotated_%d", il).c_str());
            ggml_set_name(krotated, format("krotated_%d", il).c_str());
-           log_tensor(qrotated);
            log_tensor(krotated);
+           kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
+
            struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_permute(ctx0, ggml_concat(ctx0, qrotated, qpass), 2, 1, 0, 3));
-           struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_permute(ctx0, ggml_concat(ctx0, krotated, kpass), 2, 1, 0, 3));
+           struct ggml_tensor * Kcur = ggml_cont(ctx0,
+               ggml_permute(ctx0, ggml_concat(ctx0, krotated, kpass),
+               2, 1, 0, 3)
+           );
            ggml_set_name(Qcur, format("Qcur_%d", il).c_str());
            ggml_set_name(Kcur, format("Kcur_%d", il).c_str());
            log_tensor(Qcur);
            log_tensor(Kcur);
-
+           log_tensor(kv_self.k);
            {
+               // View v as (N, n_embd)
                struct ggml_tensor * Vcur =
                    ggml_transpose(
-                       ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N)
+                       ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd, N)
                    );
                ggml_set_name(Vcur, "Vcur");
-               struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa,
-                   (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)
+
+               // Select k from kv cache as 1d view (N * n_embd)
+               struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd,
+                   (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)
                );
                ggml_set_name(k, "k");
-               struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+               struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
                    (   n_ctx)*ggml_element_size(kv_self.v),
                    (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
            }
-           struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
+           struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
            ggml_set_name(Q, "Q");
            log_tensor(Q);
-           // view kv cache?
            struct ggml_tensor * K =
-               ggml_view_3d(ctx0, kv_self.k,
+               ggml_cont(ctx0, ggml_view_3d(ctx0, kv_self.k,
                        n_embd_head, n_past + N, n_head_kv,
                        ggml_element_size(kv_self.k)*n_embd_gqa,
                        ggml_element_size(kv_self.k)*n_embd_head,
-                       ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
+                       ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il));
            ggml_set_name(K, "K");
+           log_tensor(K);
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
            ggml_set_name(KQ, "KQ");
-           struct ggml_tensor * KQ_scaled = ggml_scale_inplace (ctx0, KQ, KQ_scale);
+           struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
            ggml_set_name(KQ_scaled, "KQ_scaled");
            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
            ggml_set_name(KQ_masked, "KQ_mask");
            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
-           ggml_set_name(KQ_soft_max, format("KQ_soft_max_%d", il).c_str());
+           ggml_set_name(KQ_soft_max, format("printme_KQ_soft_max_%d", il).c_str()); // "printme" prefix triggers the mul_mat debug dump
            struct ggml_tensor * V =
                ggml_view_3d(ctx0, kv_self.v,
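
For reference, the KQ chain above (ggml_mul_mat -> ggml_scale_inplace ->
ggml_diag_mask_inf_inplace -> ggml_soft_max_inplace) computes standard causal attention
weights, now scaled by 1/sqrt(n_embd_head). A plain-C sketch of the same math for a single
query row (illustrative names, not ggml API):

    #include <math.h>
    #include <stddef.h>

    // scores[j] = dot(k_j, q) / sqrt(d) for positions j <= q_pos; future
    // positions are masked to -INFINITY, then the row is softmax'd.
    static void attn_weights_row(const float * q,  // query, length d
                                 const float * k,  // keys, n_kv rows of length d
                                 float * scores,   // output, length n_kv
                                 size_t n_kv, size_t d, size_t q_pos) {
        float max = -INFINITY;
        for (size_t j = 0; j < n_kv; ++j) {
            float dot = 0.0f;
            for (size_t i = 0; i < d; ++i) {
                dot += k[j*d + i] * q[i];
            }
            scores[j] = j <= q_pos ? dot / sqrtf((float) d) : -INFINITY;
            if (scores[j] > max) {
                max = scores[j];
            }
        }
        // numerically stable softmax; expf(-INFINITY) == 0, so masked
        // positions get exactly zero weight
        float sum = 0.0f;
        for (size_t j = 0; j < n_kv; ++j) {
            scores[j] = expf(scores[j] - max);
            sum += scores[j];
        }
        for (size_t j = 0; j < n_kv; ++j) {
            scores[j] /= sum;
        }
    }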