diff --git a/ggml.c b/ggml.c index 25fa236a2..2b727cb07 100644 --- a/ggml.c +++ b/ggml.c @@ -8831,6 +8831,12 @@ static void ggml_compute_forward_dup( struct ggml_tensor * dst) { if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) { ggml_compute_forward_dup_same_cont(params, src0, dst); + if (strncmp(src0->name, "printme_tmp_", 12) == 0 && params->ith == 0) { + GGML_PRINT("\noutputs of dupe for %s\n", src0->name); + ggml_print_tensor(dst); + int starts[] = {0, 0, 0, 0}; + ggml_print_tensor_values(dst, starts, 0, 10); + } return; } switch (src0->type) { @@ -8847,6 +8853,12 @@ static void ggml_compute_forward_dup( GGML_ASSERT(false); } break; } + if (strncmp(src0->name, "printme_tmp_", 12) == 0 && params->ith == 0) { + GGML_PRINT("\noutputs of dupe for %s\n", src0->name); + ggml_print_tensor(dst); + int starts[] = {0, 0, 0, 0}; + ggml_print_tensor_values(dst, starts, 0, 10); + } } // ggml_compute_forward_add @@ -8926,10 +8938,8 @@ static void ggml_compute_forward_add_f32( ||strncmp(src1->name, "printme", 7) == 0) && params->ith == 0) { GGML_PRINT("\noutputs of add: %s + %s\n", src0->name, src1->name); - ggml_print_tensor(src0); - ggml_print_tensor(src1); ggml_print_tensor(dst); - int starts[] = {0, 0, 0}; + int starts[] = {0, 0, 0, 0}; ggml_print_tensor_values(dst, starts, 0, 10); } } @@ -10918,7 +10928,7 @@ static void ggml_compute_forward_norm_f32( && params->ith == 0) { GGML_PRINT("\nlayernorm inputs for %s\n", src0->name); ggml_print_tensor(src0); - int starts[] = {0, 0, 0}; + int starts[] = {0, 1, 0}; ggml_print_tensor_values(src0, starts, 0, 10); } @@ -11344,19 +11354,36 @@ static void ggml_compute_forward_mul_mat( struct ggml_tensor * dst) { int64_t t0 = ggml_perf_time_us(); UNUSED(t0); - if ( - strncmp(src1->name, "printme", 7) == 0 + if ((strncmp(src0->name, "printme", 7) == 0 || + strncmp(src1->name, "printme", 7) == 0) && params->ith == 0) { GGML_PRINT("\nInputs to matmul: %s\n", src1->name); - ggml_print_tensor(src1); size_t offset = 0;//(src1->ne[0] * src1->ne[1]) - for (int i=0; i < src1->ne[0] * src1->ne[1]; ++i) { - if (i % src1->ne[0] == 0) { + size_t x = src1->ne[0]; + size_t y = src1->ne[1]; + for (int i=0; i < x * y; ++i) { + if (i % x == 0) { GGML_PRINT("\n"); } - GGML_PRINT(" %f ", ((float *)src1->data)[i + offset]); + if (i % x < 4) { + GGML_PRINT(" %f ", ((float *)src1->data)[i + offset]); + } } GGML_PRINT("\n"); + /* + GGML_PRINT("\nInputs to matmul: %s\n", src0->name); + ggml_print_tensor(src0); + if (src0->type == GGML_TYPE_F16) { + for (int i=0; i < src0->ne[0] * src0->ne[1]; ++i) { + if (i % src0->ne[0] == 0) { + GGML_PRINT("\n"); + } + GGML_PRINT(" %f", ((ggml_fp16_t *) src0->data)[i]); + } + } + GGML_PRINT("\n"); + */ + } GGML_TENSOR_BINARY_OP_LOCALS; @@ -11753,6 +11780,12 @@ static void ggml_compute_forward_scale_f32( } ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); } + if (strncmp(src0->name, "printme", 7) == 0 && params->ith == 0) { + GGML_PRINT("\nInputs of scale: %s\n", dst->name); + ggml_print_tensor(src0); + int starts[4] = {0, 0, 0, 0}; + ggml_print_tensor_values(src0, starts, 0, 32); + } } static void ggml_compute_forward_scale( @@ -11910,8 +11943,16 @@ static void ggml_compute_forward_view( const struct ggml_compute_params * params, const struct ggml_tensor * src0) { // NOP - UNUSED(params); - UNUSED(src0); + if (strncmp(src0->name, "cache_k", 7) == 0 && params->ith == 0) { + /* + GGML_PRINT("\noutputs of cache_k for view%s\n", src0->name); + ggml_print_tensor(src0); + int starts[] = {4096 * }; + ggml_print_tensor_values(src0, starts, 0, 10); + */ + } + //UNUSED(params); + //UNUSED(src0); } // ggml_compute_forward_permute @@ -12895,7 +12936,7 @@ static void ggml_compute_forward_rope_f32( if (strncmp(src0->name, "printme", 7) == 0 && params->ith == 0) { GGML_PRINT("\n dest at RoPE time for %s\n", src0->name); // print shape and strides - int starts[4] = {0,0,1,0}; + int starts[3] = {0,0,1}; ggml_print_tensor(dst); ggml_print_tensor_values(dst, starts, 0, 10); } diff --git a/llama.cpp b/llama.cpp index baf3ac0fe..0d4df77a5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2640,8 +2640,8 @@ static bool llama_model_load( } static struct ggml_cgraph * llm_build_llama( - llama_context & lctx, - const llama_batch & batch) { + llama_context & lctx, + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; @@ -2668,6 +2668,10 @@ static struct ggml_cgraph * llm_build_llama( const int32_t n_tokens = batch.n_tokens; const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + LLAMA_LOG_INFO("n_kv = %d\n", n_kv); + LLAMA_LOG_INFO("n_tokens = %d\n", n_tokens); + LLAMA_LOG_INFO("n_ctx = %d\n", n_ctx); + LLAMA_LOG_INFO("kvself.n = %d\n", kv_self.n); const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; @@ -2678,11 +2682,9 @@ static struct ggml_cgraph * llm_build_llama( struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; - params.no_alloc = true; - struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph * gf = ggml_new_graph(ctx0); @@ -2911,6 +2913,7 @@ static struct ggml_cgraph * llm_build_llama( struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); + //ggml_set_name(KQ_soft_max, format("printme_KQ_soft_max_%d", il).c_str()); // split cached V into n_head heads struct ggml_tensor * V = @@ -4077,6 +4080,7 @@ static struct ggml_cgraph * llm_build_persimmon( const auto & hparams = model.hparams; const auto & kv_self = lctx.kv_self; + GGML_ASSERT(!!kv_self.ctx); const int64_t n_embd = hparams.n_embd; @@ -4086,33 +4090,55 @@ static struct ggml_cgraph * llm_build_persimmon( const int64_t n_head = hparams.n_head; const int64_t n_embd_head = hparams.n_embd_head(); const int64_t n_embd_gqa = hparams.n_embd_gqa(); - const int32_t n_tokens = batch.n_tokens; - const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; - const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; const float freq_base = hparams.rope_freq_base; const float freq_scale = hparams.rope_freq_scale; + const float norm_eps = 1e-5f; + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head; + const size_t n_rot = n_embd_head / 2; + /* + printf("\nnorm_eps is %f\n", norm_eps); + printf("freq_base is %f\n", freq_base); + LLAMA_LOG_INFO("n_kv = %d\n", n_kv); + LLAMA_LOG_INFO("n_tokens = %d\n", n_tokens); + LLAMA_LOG_INFO("n_ctx = %d\n", n_ctx); + LLAMA_LOG_INFO("kvself.n = %d\n", kv_self.n); + */ + + const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; - GGML_ASSERT(n_embd_head == hparams.n_rot); auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ false, + /*.no_alloc =*/ true, }; - params.no_alloc = true; struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph(ctx0); + struct ggml_tensor * cur; struct ggml_tensor * inpL; + if (batch.token) { struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_allocr_alloc(lctx.alloc, inp_tokens); if (!ggml_allocr_is_measure(lctx.alloc)) { memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); } ggml_set_name(inp_tokens, "inp_tokens"); + /* + LLAMA_LOG_INFO("\ninp_tokens: ["); + for (int i = 0; i < n_tokens; ++i) { + LLAMA_LOG_INFO("%d, ", batch.token[i]); + } + LLAMA_LOG_INFO("]\n"); + */ inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); } else { inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); @@ -4121,12 +4147,32 @@ static struct ggml_cgraph * llm_build_persimmon( memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); } } + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); } ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_set_name(KQ_pos, "KQ_pos"); ggml_allocr_alloc(lctx.alloc, KQ_pos); @@ -4136,31 +4182,49 @@ static struct ggml_cgraph * llm_build_persimmon( data[i] = batch.pos[i]; } } + if (do_rope_shift) { + LLAMA_LOG_INFO("do_rope_shift...?\n"); + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + ggml_set_name(K_shift, "K_shift"); + ggml_allocr_alloc(lctx.alloc, K_shift); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * tmp = + ggml_rope_custom_inplace(ctx0, + ggml_view_3d(ctx0, kv_self.k, + n_rot, n_head, n_ctx, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*(n_embd_head*n_ctx*il)// + n_rot) + ), + K_shift, n_rot, 2, 0, freq_base, freq_scale); + ggml_build_forward_expand(gf, tmp); + } + } //LLAMA_LOG_INFO("Entering n_layers loop\n", __func__); for (int il=0; il < n_layer; ++il) { - offload_func_t offload_func = llama_nop; - // Input is (d_model, L) - // Attention + //ggml_format_name(inpL, "printme_layer_input_%d", il); struct ggml_tensor * residual = ggml_dup(ctx0, inpL); - //ggml_format_name(inpL, "printme_layer_inputs_%d", il); { - // input norming - cur = ggml_norm(ctx0, inpL, hparams.f_norm_eps); - cur = ggml_mul( - ctx0, cur, model.layers[il].attn_norm); - //ggml_format_name(cur, "printme_normed_%d", il); + //ggml_format_name(inpL, "printme_inputs_%d", il); + cur = ggml_norm(ctx0, inpL, norm_eps); + cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); + //ggml_format_name(cur, "printme_layernorm_outputs%d", il); cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b); + ggml_format_name(cur, "input_layernorm_%d", il); } - ggml_set_name(cur, "cur"); + // self attention { - // QKV //log_tensor(cur); cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); ggml_format_name(cur, "qkv_preadd_%d", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - // Apply Q, K layernorm - // split qkv GGML_ASSERT(n_head_kv == n_head); ggml_set_name(cur, format("qkv_%d", il).c_str()); @@ -4168,67 +4232,64 @@ static struct ggml_cgraph * llm_build_persimmon( // get it to (d_h, n_head, L, 3) struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2)); ggml_format_name(tmpqkv_perm, "tmpqkv_perm_%d", il); - struct ggml_tensor * tmpq = ggml_view_3d( + struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d( ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, - /* nb1 = */ ggml_element_size(tmpqkv_perm) * n_embd_head, - /* nb2 = */ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, - /* offset = */ 0 - ); + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + 0 + )); struct ggml_tensor * tmpk = ggml_view_3d( ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, - /* nb1 = */ ggml_element_size(tmpqkv_perm) * n_embd_head, - /* nb2 = */ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, - /* offset = */ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens ); - - struct ggml_tensor * tmpv = ggml_view_3d( + struct ggml_tensor * tmpv = ggml_cont(ctx0, ggml_view_3d( ctx0, tmpqkv_perm, n_embd_head, n_head, n_tokens, - /* nb1 = */ ggml_element_size(tmpqkv_perm) * n_embd_head, - /* nb2 = */ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, - /* offset = */ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2 - ); - //ggml_format_name(tmpq, "printme_tmpq_%d", il); - tmpq = ggml_norm(ctx0, tmpq, hparams.f_norm_eps); + ggml_element_size(tmpqkv_perm) * n_embd_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * n_tokens * 2 + )); + tmpq = ggml_norm(ctx0, tmpq, norm_eps); tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm); + //ggml_format_name(tmpq, "printme_tmpq_%d", il); tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b); - //ggml_format_name(tmpq, "printme_tmpk_%d", il); - tmpk = ggml_norm(ctx0, tmpk, hparams.f_norm_eps); + tmpk = ggml_norm(ctx0, tmpk, norm_eps); tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm); + //ggml_format_name(tmpk, "printme_tmpk_%d", il); tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b); - const size_t n_rot = n_embd_head / 2; struct ggml_tensor * qrot = ggml_cont(ctx0, ggml_view_3d( ctx0, tmpq, n_rot, n_head, n_tokens, - /* nb1 = */ ggml_element_size(tmpq) * n_embd_head, - /* nb2 = */ ggml_element_size(tmpq) * n_embd_head * n_head, - /* offset = */ 0 + ggml_element_size(tmpq) * n_embd_head, + ggml_element_size(tmpq) * n_embd_head * n_head, + 0 )); - // get the second half of tmpq, e.g tmpq[n_rot:, :, :] - struct ggml_tensor * qpass = ggml_cont(ctx0, ggml_view_3d( - ctx0, tmpq, n_rot, n_head, n_tokens, - /* nb1 = */ ggml_element_size(tmpq) * n_embd_head, - /* nb2 = */ ggml_element_size(tmpq) * n_embd_head * n_head, - /* offset = */ ggml_element_size(tmpq) * n_rot - )); - ggml_set_name(qrot, format("qrot_%d", il).c_str()); - ggml_set_name(qpass, format("qpass_%d", il).c_str()); - //log_tensor(qrot); - //log_tensor(qpass); - - struct ggml_tensor * krot = ggml_cont(ctx0, ggml_view_3d( + struct ggml_tensor * krottmp = ggml_view_3d( ctx0, tmpk, n_rot, n_head, n_tokens, /* nb1 = */ ggml_element_size(tmpk) * n_embd_head, /* nb2 = */ ggml_element_size(tmpk) * n_embd_head * n_head, /* offset = */ 0 + ); + //ggml_format_name(krottmp, "printme_krottmp_%d", il); + struct ggml_tensor * krot = ggml_cont(ctx0, krottmp); + // get the second half of tmpq, e.g tmpq[n_rot:, :, :] + struct ggml_tensor * qpass = ggml_cont(ctx0, ggml_view_3d( + ctx0, tmpq, n_rot, n_head, n_tokens, + ggml_element_size(tmpq) * n_embd_head, + ggml_element_size(tmpq) * n_embd_head * n_head, + ggml_element_size(tmpq) * n_rot )); struct ggml_tensor * kpass = ggml_cont(ctx0, ggml_view_3d( - ctx0, tmpk, n_rot, n_head, n_tokens, - /* nb1 = */ ggml_element_size(tmpk) * n_embd_head, - /* nb2 = */ ggml_element_size(tmpk) * n_embd_head * n_head, - /* offset = */ ggml_element_size(tmpk) * n_rot + ctx0, tmpk, n_rot, n_head, n_tokens, + ggml_element_size(tmpk) * n_embd_head, + ggml_element_size(tmpk) * n_embd_head * n_head, + ggml_element_size(tmpk) * n_rot )); - ggml_set_name(krot, format("krot_%d", il).c_str()); + ggml_set_name(qrot, format("qrot_%d", il).c_str()); + //ggml_set_name(krot, format("printme_krot_%d", il).c_str()); + ggml_set_name(qpass, format("qpass_%d", il).c_str()); ggml_set_name(kpass, format("kpass_%d", il).c_str()); struct ggml_tensor * qrotated = ggml_cont(ctx0, ggml_permute(ctx0, @@ -4239,7 +4300,6 @@ static struct ggml_cgraph * llm_build_persimmon( )); qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3)); - //ggml_format_name(krot, "printme_krot_%d", il); struct ggml_tensor * krotated = ggml_cont(ctx0, ggml_permute(ctx0, ggml_rope_custom( ctx0, krot, KQ_pos, n_rot, 2, 0, freq_base, freq_scale @@ -4252,18 +4312,38 @@ static struct ggml_cgraph * llm_build_persimmon( ggml_permute(ctx0, ggml_concat(ctx0, qrotated, qpass), 2, 1, 0, 3)); - struct ggml_tensor * Kcur = ggml_cont(ctx0, - ggml_permute(ctx0, ggml_concat(ctx0, krotated, kpass), - 2, 1, 0, 3) - ); + struct ggml_tensor * tmp = ggml_permute(ctx0, ggml_concat(ctx0, krotated, kpass), 2, 1, 0, 3); + //ggml_format_name(tmp, "printme_tmp_%d", il); + struct ggml_tensor * Kcur = ggml_cont(ctx0, tmp); ggml_set_name(Qcur, format("Qcur_%d", il).c_str()); + // kcur appears healthy. ggml_set_name(Kcur, format("Kcur_%d", il).c_str()); { - struct ggml_tensor * Vcur = ggml_transpose( - ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd, n_tokens) - ); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens)); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd, + + struct ggml_tensor * k = ggml_view_1d( + ctx0, kv_self.k, n_tokens*n_embd_gqa, + (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head) + ); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + ggml_set_name(v, "v"); + + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + /* + struct ggml_tensor * Vcur = ggml_cont(ctx0, + ggml_transpose( + ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd, n_tokens) + )); + ggml_set_name(Vcur, "Vcur"); + struct ggml_tensor * k = ggml_view_1d( + ctx0, kv_self.k, n_tokens*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + kv_head) ); ggml_set_name(k, "k"); @@ -4274,28 +4354,28 @@ static struct ggml_cgraph * llm_build_persimmon( ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + */ } - struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3)); + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); ggml_set_name(Q, "Q"); //log_tensor(Q); - - struct ggml_tensor * K = - ggml_cont(ctx0, ggml_view_3d(ctx0, kv_self.k, + // For some reason this is all zeros and no balls... + struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, n_embd_head, n_kv, n_head_kv, ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il)); - - ggml_set_name(K, "K"); + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + //ggml_format_name(K, "printme_K_%d", il); + //log_tensor(K); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - ggml_set_name(KQ, "KQ"); - - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + //ggml_set_name(KQ, "KQ"); + //ggml_format_name(KQ, "printme_KQ_%d", il); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); ggml_set_name(KQ_scaled, "KQ_scaled"); - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_kv); - ggml_set_name(KQ_masked, "KQ_mask"); + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); + ggml_set_name(KQ_masked, "KQ_masked"); struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); //ggml_set_name(KQ_soft_max, format("printme_KQ_soft_max_%d", il).c_str()); @@ -4314,10 +4394,11 @@ static struct ggml_cgraph * llm_build_persimmon( struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); ggml_set_name(KQV_merged, "KQV_merged"); - cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens)); + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); ggml_set_name(cur, "KQV_merged_contiguous"); cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); + //ggml_format_name(cur, "printme_wo_%d", il); cur = ggml_add(ctx0, cur, model.layers[il].bo); ggml_set_name(cur, "result_wo"); } @@ -4326,7 +4407,7 @@ static struct ggml_cgraph * llm_build_persimmon( ggml_set_name(residual2, "residual2"); // Norm { - cur = ggml_norm(ctx0, cur, hparams.f_norm_eps); + cur = ggml_norm(ctx0, cur, norm_eps); cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b @@ -4334,8 +4415,7 @@ static struct ggml_cgraph * llm_build_persimmon( } cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur); cur = ggml_add(ctx0, cur, model.layers[il].b3); - cur = ggml_relu(ctx0, cur); - cur = ggml_sqr(ctx0, cur); + cur = ggml_sqr(ctx0, ggml_relu(ctx0, cur)); cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); //ggml_format_name(cur, "printme_ffn_down_%d", il); struct ggml_tensor * ffn_out = ggml_add(ctx0, @@ -4348,10 +4428,10 @@ static struct ggml_cgraph * llm_build_persimmon( } cur = inpL; { - cur = ggml_norm(ctx0, cur, hparams.f_norm_eps); - cur = ggml_add(ctx0, - ggml_mul(ctx0, cur, model.output_norm), - model.output_norm_b); + cur = ggml_norm(ctx0, cur, norm_eps); + cur = ggml_mul(ctx0, cur, model.output_norm); + //ggml_set_name(cur, "printme_final"); + cur = ggml_add(ctx0, cur, model.output_norm_b); ggml_set_name(cur, "result_norm"); } cur = ggml_mul_mat(ctx0, model.output, cur);