diff --git a/convert-adept-st-to-gguf.py b/convert-adept-st-to-gguf.py
index 4844d5f81..1a6eda8a1 100644
--- a/convert-adept-st-to-gguf.py
+++ b/convert-adept-st-to-gguf.py
@@ -87,12 +87,16 @@ def main(args_in: list[str] | None = None) -> None:
     gguf_writer.add_rope_dimension_count(hidden_size // head_count)
     gguf_writer.add_head_count(head_count)
     gguf_writer.add_head_count_kv(head_count_kv)
+    gguf_writer.add_rope_freq_base(hparams['rotary_emb_base'])
+    gguf_writer.add_layer_norm_eps(hparams['layernorm_epsilon'])
     if True:
         tokens, scores, toktypes = handle_tokenizer(dir_model)
         gguf_writer.add_tokenizer_model('llama')
         gguf_writer.add_token_list(tokens)
         gguf_writer.add_token_scores(scores)
         gguf_writer.add_token_types(toktypes)
+        gguf_writer.add_bos_token_id(71013)
+        gguf_writer.add_eos_token_id(71013)
     tensor_map = gguf.get_tensor_name_map(arch, block_count)
     print(tensor_map)
     tensors = {}
@@ -105,15 +109,17 @@ def main(args_in: list[str] | None = None) -> None:
         print(name)

         # we don't need these
-
-        if name.endswith(".self_attention.rotary_emb.inv_freq"):
+        if name.endswith(".self_attention.rotary_emb.inv_freq"):
             continue
         old_dtype = data.dtype
-        if 'layernorm.weight' in name:
+        """
+        if 'layernorm.weight' in name or 'word_embeddings.weight' in name:
             data = data.to(torch.float32)
         else:
             if data.dtype != torch.float16 and data.dtype != torch.float32:
-                data = data.to(torch.float16)
+                data = data.to(torch.float32)
+        """
+        data = data.to(torch.float32)
         # check for nans
         if torch.isnan(data).any():
             print("WARNING: tensor '" + name + "' contains NaNs")
diff --git a/ggml.c b/ggml.c
index a0be068d6..2f02865fc 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4290,6 +4290,65 @@ void ggml_print_objects(const struct ggml_context * ctx) {
     GGML_PRINT("%s: --- end ---\n", __func__);
 }

+static void ggml_print_tensor(const struct ggml_tensor * tensor) {
+    GGML_PRINT("Tensor (null): %s | rank %d | shape (", ggml_type_name(tensor->type), tensor->n_dims);
+    for (int i=0; i<tensor->n_dims; ++i) {
+        GGML_PRINT("%lld ", tensor->ne[i]);
+    }
+    GGML_PRINT(") | strides (");
+    for (int i=0; i<tensor->n_dims; ++i) {
+        GGML_PRINT("%lld ", tensor->nb[i]);
+    }
+    GGML_PRINT(")\n");
+}
+
+static void ggml_print_tensor_values(const struct ggml_tensor * tensor, int starts[], int dim, int nelts) {
+    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
+    GGML_PRINT("printing values for %s[", tensor->name);
+    for (int i=0; i<tensor->n_dims; ++i) {
+        if (i!=dim) {
+            GGML_PRINT("%d", starts[i]);
+        } else {
+            if (starts[i] > 0) {
+                GGML_PRINT("%d:%d", starts[i], starts[i]+nelts);
+            } else {
+                GGML_PRINT(":%d", starts[i]+nelts);
+            }
+        }
+        if (i < tensor->n_dims-1) {
+            GGML_PRINT(",");
+        }
+    }
+    GGML_PRINT("]\n");
+
+    float *dataPtr = (float *) tensor->data;
+
+    // Compute the offset into data for starts
+    int offset = 0;
+    for (int j = 0; j < tensor->n_dims; j++) {
+        offset += (starts[j] * tensor->nb[j]) / sizeof(float); // nb[j] is in bytes, so divide by sizeof(float) to get a float offset.
+    }
+
+    dataPtr += offset;
+
+    for (int i = 0; i < nelts; i++) {
+        GGML_PRINT("%f ", *dataPtr);
+        dataPtr += tensor->nb[dim] / sizeof(float); // Increment by the stride for the given dimension.
+    }
+    GGML_PRINT("\n");
+    /*
+    char * ptr = (char *)tensor->data;
+    for (int j=0; j<tensor->n_dims; j++) {
+        ptr += tensor->nb[j]*starts[j];
+    }
+    for (int i=0; i<nelts; ++i) {
+        GGML_PRINT("%f ", *(float *)ptr);
+        ptr += tensor->nb[dim];
+    }
+    GGML_PRINT("\n");
+    */
+}
+
 int64_t ggml_nelements(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
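An aside on the helpers above, since the stride arithmetic is easy to get wrong: ggml stores the strides nb[i] in bytes, not elements, which is why ggml_print_tensor_values divides by sizeof(float). A minimal standalone sketch of the same byte-stride indexing for a rank-3 F32 tensor (illustrative only, not part of the patch):

    #include "ggml.h"

    // Fetch element (i0, i1, i2) of an F32 tensor by walking the byte
    // strides nb[], the same arithmetic ggml_print_tensor_values performs.
    static float get_f32_3d(const struct ggml_tensor * t, int i0, int i1, int i2) {
        const char * base = (const char *) t->data;
        return *(const float *)(base + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2]);
    }

The helper's print loop is a special case of this: it resolves starts[] across every dimension once, then advances along the single dimension dim.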
@@ -6162,6 +6221,7 @@ struct ggml_tensor * ggml_mul_mat(
     const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
+    GGML_PRINT("ggml_mul_mat result shape : (%lld, %lld, %lld, %lld)\n", ne[0], ne[1], ne[2], ne[3]);

     result->op   = GGML_OP_MUL_MAT;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -8823,6 +8883,15 @@ static void ggml_compute_forward_add_f32(
             }
         }
     }
+    if ((strncmp(src0->name, "preadd", 6) == 0
+         || strncmp(src0->name, "qkv_preadd", 10) == 0)
+        && ith == 0) {
+        // print name
+        printf("\nadd outputs for %s\n", src0->name);
+        ggml_print_tensor(dst);
+        int starts[] = {0, 3, 0};
+        ggml_print_tensor_values(dst, starts, 0, 10);
+    }
 }

 static void ggml_compute_forward_add_f16_f32(
@@ -10804,6 +10873,18 @@ static void ggml_compute_forward_norm_f32(
     }

     GGML_ASSERT(src0->nb[0] == sizeof(float));
+    // If the name starts with "layer_inputs" and we are on thread 0, print the tensor
+    if ((strncmp(src0->name, "layer_inputs", 12) == 0
+         || strncmp(src0->name, "tmpq", 4) == 0)
+        && params->ith == 0) {
+        GGML_PRINT("\nlayernorm inputs for %s\n", src0->name);
+        ggml_print_tensor(src0);
+        int starts[] = {0, 1, 0};
+        ggml_print_tensor_values(src0, starts, 0, 10);
+        for (int i=64; i<74; ++i) {
+            GGML_PRINT("%f ", ggml_get_f32_1d(src0, i));
+        }
+    }

     const int ith = params->ith;
     const int nth = params->nth;
@@ -11227,8 +11308,25 @@ static void ggml_compute_forward_mul_mat(
         struct ggml_tensor * dst) {
     int64_t t0 = ggml_perf_time_us();
     UNUSED(t0);
+    if (strncmp(src1->name, "KQ_soft_max", 11) == 0 && params->ith == 0
+        && src1->ne[0] == src1->ne[1]) {
+        GGML_PRINT("\n KQ_softmax at mul mat time for %s\n", src1->name);
+        ggml_print_tensor(src1);
+        if (ggml_nelements(src1) >= 14) {
+            for (int i=0; i < src1->ne[0] * src1->ne[1]; ++i) {
+                if (i % src1->ne[1] == 0) {
+                    GGML_PRINT("\n");
+                }
+                GGML_PRINT(" %f ", ((float *)src1->data)[i]);
+            }
+            GGML_PRINT("\n");
+        } else {
+            GGML_PRINT("Not enough elements to print\n");
+        }
+    }

     GGML_TENSOR_BINARY_OP_LOCALS;
+    // If on thread 0 and src1 starts with KQ_softmax, print

     const int ith = params->ith;
     const int nth = params->nth;
@@ -12628,6 +12726,12 @@ static void ggml_compute_forward_rope_f32(
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
+    if (strncmp(src0->name, "qrot", 4) == 0 && params->ith == 0) {
+        GGML_PRINT("\nValues at RoPE time for %s\n", src0->name);
+        ggml_print_tensor(src0);
+        int starts[] = {0, 0, 1, 0};
+        ggml_print_tensor_values(src0, starts, 1, 10);
+    }

     float freq_base;
     float freq_scale;
@@ -12756,6 +12860,13 @@ static void ggml_compute_forward_rope_f32(
             }
         }
     }
+    if (strncmp(src0->name, "qrot", 4) == 0 && params->ith == 0) {
+        GGML_PRINT("\n dest at RoPE time for %s\n", src0->name);
+        // print shape and strides
+        int starts[4] = {0,0,0,0};
+        ggml_print_tensor(dst);
+        ggml_print_tensor_values(dst, starts, 0, 10);
+    }
 }

 static void ggml_compute_forward_rope_f16(
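All of the ggml.c instrumentation in these hunks follows one pattern: match a tensor by the name assigned with ggml_set_name/ggml_format_name in the graph builder, and print only on compute thread 0 so the dumps are not interleaved across threads. A hypothetical helper capturing that pattern (debug_print_named is not in the patch; it assumes placement inside ggml.c next to the static helpers above):

    #include <string.h>

    // Dump a tensor once per op when its graph name starts with `prefix`
    // and we are on compute thread 0 (ith == 0).
    static void debug_print_named(const struct ggml_tensor * t, const char * prefix, int ith) {
        if (ith == 0 && strncmp(t->name, prefix, strlen(prefix)) == 0) {
            ggml_print_tensor(t);
        }
    }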
diff --git a/llama.cpp b/llama.cpp
index c354f1ef2..a8a724c2c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2337,25 +2337,6 @@ static void llm_load_tensors(
         const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
         const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT;
         auto & layer = model.layers[i];
-        /*
-        input_layernorm.bias torch.Size([4096])
-        input_layernorm.weight torch.Size([4096])
-        mlp.dense_4h_to_h.bias torch.Size([4096])
-        mlp.dense_4h_to_h.weight torch.Size([4096, 16384])
-        mlp.dense_h_to_4h.bias torch.Size([16384])
-        mlp.dense_h_to_4h.weight torch.Size([16384, 4096])
-        post_attention_layernorm.bias torch.Size([4096])
-        post_attention_layernorm.weight torch.Size([4096])
-        self_attention.dense.bias torch.Size([4096])
-        self_attention.dense.weight torch.Size([4096, 4096])
-        self_attention.k_layernorm.bias torch.Size([64])
-        self_attention.k_layernorm.weight torch.Size([64])
-        self_attention.q_layernorm.bias torch.Size([64])
-        self_attention.q_layernorm.weight torch.Size([64])
-        self_attention.query_key_value.bias torch.Size([12288])
-        self_attention.query_key_value.weight torch.Size([12288, 4096])
-        self_attention.rotary_emb.inv_freq torch.Size([16])
-        */
         layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
         layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
         layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
@@ -3744,6 +3725,20 @@ static struct ggml_cgraph * llm_build_starcoder(
     return gf;
 }

+static void log_tensor(
+    ggml_tensor * a
+) {
+    LLAMA_LOG_INFO("Shape of %s is ", a->name);
+    for (int i = 0; i < a->n_dims; ++i) {
+        LLAMA_LOG_INFO("%d", a->ne[i]);
+        if (i < a->n_dims - 1) {
+            LLAMA_LOG_INFO(",");
+        }
+        LLAMA_LOG_INFO(" ");
+    }
+    LLAMA_LOG_INFO("\n");
+}
+
 static struct ggml_cgraph * llm_build_adept(
     llama_context & lctx,
     const llama_token * tokens,
@@ -3760,7 +3755,7 @@ static struct ggml_cgraph * llm_build_adept(
     GGML_ASSERT(!!kv_self.ctx);

     const int64_t n_embd = hparams.n_embd;
-    const int64_t n_layer = hparams.n_layer;
+    const int64_t n_layer = 1;
     const int64_t n_ctx = hparams.n_ctx;
     const int64_t n_head_kv = hparams.n_head_kv;
     const int64_t n_head = hparams.n_head;
@@ -3785,18 +3780,28 @@ static struct ggml_cgraph * llm_build_adept(
     struct ggml_tensor * inpL;

     if (tokens) {
         struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-
         ggml_allocr_alloc(lctx.alloc, inp_tokens);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
-            memcpy(inp_tokens->data, inp_tokens, N*ggml_element_size(inp_tokens));
+            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
         }
+        ggml_set_name(inp_tokens, "inp_tokens");
         LLAMA_LOG_INFO("Token ids:\n", __func__);
         for (int i = 0; i < N; ++i) {
            LLAMA_LOG_INFO(" %d ", tokens[i]);
         }
-        ggml_set_name(inp_tokens, "inp_tokens");
-
+        LLAMA_LOG_INFO("\n", __func__);
         inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
+        /*
+        LLAMA_LOG_INFO("\ninpL:\n", __func__);
+        if (ggml_nelements(model.tok_embeddings) >= 5) {
+            for (int i=0; i < 5; ++i) {
+                LLAMA_LOG_INFO(" %f ", ggml_get_f32_1d(model.tok_embeddings, i));
+            }
+            LLAMA_LOG_INFO("\n");
+        } else {
+            LLAMA_LOG_INFO("Not enough elements to print\n", __func__);
+        }
+        */
     } else {
         inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
         ggml_allocr_alloc(lctx.alloc, inpL);
@@ -3804,7 +3809,6 @@ static struct ggml_cgraph * llm_build_adept(
             memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
         }
     }
-    // Log all of the token ids sequentially
     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
     ggml_allocr_alloc(lctx.alloc, KQ_scale);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
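Aside, not part of the patch: KQ_scale is a one-element tensor holding the standard scaled-dot-product factor, 1/sqrt(n_embd_head), applied to K^T*Q before masking and softmax (the tensor's name in the next hunk's leading context says as much). For Persimmon-8B the head dimension is 64, per the q/k layernorm shapes in the comment block removed above, so scores are scaled by 1/8:

    #include <math.h>

    // Illustrative: the value stored in KQ_scale and multiplied into K^T Q
    // before masking and softmax. With n_embd_head = 64 this is 0.125f.
    static inline float kq_scale_value(int n_embd_head) {
        return 1.0f / sqrtf((float) n_embd_head);
    }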
@@ -3813,63 +3817,159 @@ static struct ggml_cgraph * llm_build_adept(
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
     //LLAMA_LOG_INFO("Entering n_layers loop\n", __func__);
     for (int il=0; il < n_layer; ++il) {
-        struct ggml_tensor * attn_norm;
         offload_func_t offload_func = llama_nop;
+        // Input is (d_model, L)
         // Attention
+        struct ggml_tensor * residual = inpL;
+        ggml_set_name(residual, format((char*)"layer_inputs_%d", il).c_str());
         {
             // input norming
-            attn_norm = ggml_norm(ctx0, inpL, hparams.f_norm_eps);
-            attn_norm = ggml_add(ctx0, ggml_mul(
-                ctx0, attn_norm, model.layers[il].attn_norm),
+            cur = ggml_norm(ctx0, inpL, hparams.f_norm_eps);
+            cur = ggml_add(ctx0, ggml_mul(
+                ctx0, cur, model.layers[il].attn_norm),
                 model.layers[il].attn_norm_b);
-
-            // QKV + bias
-            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm);
+        }
+        ggml_set_name(cur, "cur");
+        {
+            // QKV
+            log_tensor(cur);
+            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+            // 3 * d_model, L
+            // or 2 * n_head_kv + n_embd_head, L
+            // + bias
+            ggml_format_name(cur, "qkv_preadd_%d", il);
             cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-            const size_t wsize = ggml_type_size(cur->type);

             // Apply Q, K layernorm
+            // Where is the Q/K/V? It's in order. Hopefully...
+            // So q has offset 0.
+            // And split into heads
+            // -> (d_h, n_head, L)
+            const size_t wsize = ggml_type_size(cur->type);
+            GGML_ASSERT(n_head_kv == n_head);
+            LLAMA_LOG_INFO("N: %d\n", N);
+            ggml_set_name(cur, format("qkv_%d", il).c_str());
+            log_tensor(cur);
+
+            // cur is (3 * d_head * n_head, N)
+            struct ggml_tensor * tmpqkv = ggml_view_4d(
+                ctx0, cur, n_embd_head, 3, n_head, N,
+                /* nb1 = */ wsize * n_embd_head,
+                /* nb2 = */ wsize * n_embd_head * 3,
+                /* nb3 = */ wsize * n_embd_head * 3 * n_head,
+                /* offset = */ 0
+            );
+            // get it to (d_h, n_head, L, 3)
+            struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2));
+            ggml_format_name(tmpqkv_perm, "tmpqkv_perm_%d", il);
+            log_tensor(tmpqkv_perm);
             struct ggml_tensor * tmpq = ggml_cont(
-                ctx0, ggml_view_3d(
-                    ctx0, cur, n_embd_head, n_head, N,
-                    wsize * n_embd_head,
-                    wsize * n_embd_head * (n_head + 2 * n_head_kv),
-                    0
+                ctx0,
+                ggml_view_3d(
+                    ctx0, tmpqkv_perm, n_embd_head, n_head, N,
+                    /* nb1 = */ sizeof(float) * n_embd_head,
+                    /* nb2 = */ sizeof(float) * n_embd_head * n_head,
+                    /* offset = */ 0
                 )
             );
             struct ggml_tensor * tmpk = ggml_cont(
-                ctx0, ggml_view_3d(
-                    ctx0, cur, n_embd_head, n_head, N,
-                    wsize * n_embd_head,
-                    wsize * n_embd_head * (n_head + 2 * n_head_kv),
-                    wsize * n_embd_head * n_head
+                ctx0,
+                ggml_view_3d(
+                    ctx0, tmpqkv_perm, n_embd_head, n_head, N,
+                    /* nb1 = */ sizeof(float) * n_embd_head,
+                    /* nb2 = */ sizeof(float) * n_embd_head * n_head,
+                    /* offset = */ sizeof(float) * n_embd_head * n_head * N
                 )
             );
+            struct ggml_tensor * tmpv = ggml_cont(
+                ctx0,
+                ggml_view_3d(
+                    ctx0, tmpqkv_perm, n_embd_head, n_head, N,
+                    /* nb1 = */ sizeof(float) * n_embd_head,
+                    /* nb2 = */ sizeof(float) * n_embd_head * n_head,
+                    /* offset = */ sizeof(float) * n_embd_head * n_head * N * 2
+                )
+            );
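Aside, not part of the patch: the ggml_view_4d above encodes the layout of the fused QKV activation. Per token, the heads are contiguous, and within each head the q, k, and v vectors of length n_embd_head follow one another, so the flat float index of element d of head h for token t is as sketched below (illustrative; `which` is 0/1/2 for q/k/v):

    #include <stddef.h>

    // Flat float index into the fused QKV activation matching the
    // (n_embd_head, 3, n_head, N) view above.
    static size_t qkv_index(size_t t, size_t h, size_t which, size_t d,
                            size_t n_head, size_t n_embd_head) {
        return ((t * n_head + h) * 3 + which) * n_embd_head + d;
    }

The permute to (d_h, n_head, L, 3) followed by ggml_cont then makes each of q, k, v one contiguous (d_h, n_head, L) block, which is what the three ggml_view_3d offsets of 0, 1x, and 2x n_embd_head * n_head * N floats rely on.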
+            ggml_set_name(tmpq, format("tmpq_%d", il).c_str());
+            tmpq = ggml_norm(ctx0, tmpq, hparams.f_norm_eps);
+            tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm);
+            ggml_set_name(tmpq, format("preadd_%d", il).c_str());
+            tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b);
+
             tmpk = ggml_norm(ctx0, tmpk, hparams.f_norm_eps);
             tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm);
             tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b);
-
-            tmpq = ggml_norm(ctx0, tmpq, hparams.f_norm_eps);
-            tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm);
-            tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b);
+            ggml_set_name(tmpq, format("tmpq_%d", il).c_str());
+            ggml_set_name(tmpk, format("tmpk_%d", il).c_str());
+            log_tensor(tmpq);
+            log_tensor(tmpk);

-            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
-                ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale
-            );
-            struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
-                ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale
-            );
+            const size_t n_rot = n_embd_head / 2;
+            struct ggml_tensor * qrot = ggml_cont(ctx0, ggml_view_3d(
+                ctx0, tmpq, n_rot, n_head, N,
+                /* nb1 = */ wsize * n_embd_head,
+                /* nb2 = */ wsize * n_embd_head * n_head,
+                /* offset = */ 0
+            ));
+            struct ggml_tensor * qpass = ggml_cont(ctx0, ggml_permute(ctx0, ggml_view_3d(
+                ctx0, tmpq, n_rot, n_head, N,
+                /* nb1 = */ wsize * n_rot,
+                /* nb2 = */ wsize * n_rot * n_head,
+                /* offset = */ (wsize * n_embd_head * n_head) / 2
+            ), 2, 1, 0, 3));
+            ggml_set_name(qrot, format("qrot_%d", il).c_str());
+            ggml_set_name(qpass, format("qpass_%d", il).c_str());
+            log_tensor(qrot);
+            log_tensor(qpass);

-            struct ggml_tensor * tmpv = ggml_view_3d(
-                ctx0, cur, n_embd_head, n_head_kv, N,
-                wsize * n_embd_head,
-                wsize * n_embd_head * (n_head + 2 * n_head_kv),
-                wsize * n_embd_head * (n_head + n_head_kv));
+            struct ggml_tensor * krot = ggml_cont(ctx0, ggml_view_3d(
+                ctx0, tmpk, n_rot, n_head, N,
+                /* nb1 = */ wsize * n_rot,
+                /* nb2 = */ wsize * n_rot * n_head,
+                /* offset = */ 0
+            ));
+            struct ggml_tensor * kpass = ggml_cont(ctx0,
+                ggml_permute(ctx0,
+                    ggml_view_3d(
+                        ctx0, tmpk, n_rot, n_head, N,
+                        /* nb1 = */ wsize * n_rot,
+                        /* nb2 = */ wsize * n_rot * n_head,
+                        /* offset = */ (wsize * n_embd_head * n_head) / 2
+                    ), 2, 1, 0, 3));
+            ggml_set_name(krot, format("krot_%d", il).c_str());
+            ggml_set_name(kpass, format("kpass_%d", il).c_str());
+            log_tensor(krot);
+            log_tensor(kpass);
+
+            struct ggml_tensor * qrotated = ggml_cont(ctx0, ggml_permute(ctx0,
+                ggml_rope_custom_inplace(
+                    ctx0, qrot, n_past, n_rot, 0, 0, freq_base, freq_scale
+                ),
+                2, 1, 0, 3
+            ));
+            struct ggml_tensor * krotated = ggml_cont(ctx0, ggml_permute(ctx0,
+                ggml_rope_custom_inplace(
+                    ctx0, krot, n_past, n_rot, 0, 0, freq_base, freq_scale
+                ),
+                2, 1, 0, 3
+            ));
+            ggml_set_name(qrotated, format("qrotated_%d", il).c_str());
+            ggml_set_name(krotated, format("krotated_%d", il).c_str());
+            log_tensor(qrotated);
+            log_tensor(krotated);
+            struct ggml_tensor * Qcur = ggml_cont(ctx0,
+                ggml_permute(ctx0,
+                    ggml_concat(ctx0, qrotated, qpass),
+                    2, 1, 0, 3));
+            struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_permute(ctx0, ggml_concat(ctx0, krotated, kpass), 2, 1, 0, 3));
+            ggml_set_name(Qcur, format("Qcur_%d", il).c_str());
+            ggml_set_name(Kcur, format("Kcur_%d", il).c_str());
+            log_tensor(Qcur);
+            log_tensor(Kcur);
             {
-                // Set kv cache elements?
                 struct ggml_tensor * Vcur = ggml_transpose(
                     ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N)
                 );
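Aside, not part of the patch: Persimmon rotates only the first n_rot = n_embd_head / 2 dimensions of each head; the qpass/kpass halves bypass RoPE entirely and are concatenated back after the rotation. Note the new calls use mode 0 (adjacent-pair rotation), where the removed code used NeoX-style mode 2. A scalar sketch of what ggml's rope kernel does to pair i at position pos, under those assumptions and with freq_scale left at its 1.0 default:

    #include <math.h>

    // Rotate one adjacent pair (x0, x1) of the rotated half.
    // theta_i = pos * freq_base^(-2*i/n_rot), as in ggml_compute_forward_rope_f32.
    static void rope_rotate_pair(float * x0, float * x1, int pos, int i,
                                 int n_rot, float freq_base) {
        const float theta = (float) pos * powf(freq_base, -2.0f * (float) i / (float) n_rot);
        const float c = cosf(theta);
        const float s = sinf(theta);
        const float a = *x0;
        const float b = *x1;
        *x0 = a * c - b * s;
        *x1 = a * s + b * c;
    }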
@@ -3886,11 +3986,11 @@ static struct ggml_cgraph * llm_build_adept(
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
             }
-            //LLAMA_LOG_INFO("3889\n", __func__);
             struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
             ggml_set_name(Q, "Q");
+            log_tensor(Q);

-            // index into kv cache?
+            // view kv cache?
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
                     n_embd_head, n_past + N, n_head_kv,
@@ -3907,12 +4007,11 @@ static struct ggml_cgraph * llm_build_adept(
             ggml_set_name(KQ_scaled, "KQ_scaled");

             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
-            ggml_set_name(KQ_masked, "KQ_soft_max");
+            ggml_set_name(KQ_masked, "KQ_mask");

             struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
-            ggml_set_name(KQ_soft_max, "KQ_soft_max");
+            ggml_set_name(KQ_soft_max, format("KQ_soft_max_%d", il).c_str());

-            //LLAMA_LOG_INFO("3915\n", __func__);
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
                     n_past + N, n_embd_head, n_head_kv,
@@ -3932,15 +4031,19 @@ static struct ggml_cgraph * llm_build_adept(
             cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
             ggml_set_name(cur, "result_wo");
-            //LLAMA_LOG_INFO("EoWo\n", __func__);
         }
-        struct ggml_tensor * attn_out = cur;
+        cur = ggml_add(ctx0, residual, cur);
+        residual = cur;
+        ggml_set_name(residual, "residual");
         {
-            struct ggml_tensor * inpFF = attn_norm;
+            struct ggml_tensor * inpFF = cur;
             // Norm
             {
                 cur = ggml_norm(ctx0, inpFF, hparams.f_norm_eps);
-                cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
+                cur = ggml_add(ctx0,
+                    ggml_mul(ctx0, cur, model.layers[il].ffn_norm),
+                    model.layers[il].ffn_norm_b
+                );
             }
             cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);

             // Squared ReLU
@@ -3948,31 +4051,22 @@ static struct ggml_cgraph * llm_build_adept(
             cur = ggml_relu(ctx0, cur);
             cur = ggml_mul(ctx0, cur, cur);
             cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
         }
-        cur = ggml_add(ctx0, cur, attn_out);
+        cur = ggml_add(ctx0, cur, residual);
         ggml_set_name(cur, "inpFF_+_attn_out");
         inpL = cur;
-        //LLAMA_LOG_INFO("EoL\n", __func__);
     }
-    //LLAMA_LOG_INFO("Exited from n_layers loop\n", __func__);
     cur = inpL;
     {
-        //LLAMA_LOG_INFO("norm\n", __func__);
         cur = ggml_norm(ctx0, cur, hparams.f_norm_eps);
-        //LLAMA_LOG_INFO("ggml_norm\n", __func__);
         cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b);
-        //LLAMA_LOG_INFO("result_norm\n", __func__);
         ggml_set_name(cur, "result_norm");
     }
-    //LLAMA_LOG_INFO("matmul\n", __func__);
     cur = ggml_mul_mat(ctx0, model.output, cur);
     ggml_set_name(cur, "result_output");
-    //LLAMA_LOG_INFO("bf expand\n", __func__);
     ggml_build_forward_expand(gf, cur);
-    //LLAMA_LOG_INFO("Freeing ctx0\n", __func__);
     ggml_free(ctx0);
-    //LLAMA_LOG_INFO("Exiting fun\n", __func__);
     return gf;
 }
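For reference, the "Squared ReLU" kept as context in the MLP hunk is Persimmon's activation, relu(x)^2, built from existing ggml ops as a ReLU followed by an elementwise self-multiply. As a scalar sketch:

    // Squared-ReLU activation, y = max(x, 0)^2, matching the
    // ggml_relu followed by ggml_mul(cur, cur) pair in the graph above.
    static inline float squared_relu(float x) {
        const float r = x > 0.0f ? x : 0.0f;
        return r * r;
    }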