From 3e3ed9560cdabc6147cd149ee0ecc81b52aa4441 Mon Sep 17 00:00:00 2001
From: xaedes
Date: Thu, 11 May 2023 19:31:46 +0200
Subject: [PATCH] add parallel batched forward function for baby-llama training

---
 examples/baby-llama/baby-llama.cpp | 427 +++++++++++++++++++++++++++--
 1 file changed, 403 insertions(+), 24 deletions(-)

diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp
index 071ae1793..60d81bc4a 100644
--- a/examples/baby-llama/baby-llama.cpp
+++ b/examples/baby-llama/baby-llama.cpp
@@ -133,6 +133,10 @@ struct llama_hparams {
     }
 };
 
+uint32_t get_n_ff(const struct llama_hparams* hparams) {
+    uint32_t n_ff = ((2*(4*hparams->n_embd)/3 + hparams->n_mult - 1)/hparams->n_mult)*hparams->n_mult;
+    return n_ff;
+}
 
 struct llama_hparams_lora {
     uint32_t n_vocab = 32000;
@@ -237,7 +241,7 @@ void init_model(struct llama_model * model) {
     const uint32_t n_layer = hparams.n_layer;
     const uint32_t n_vocab = hparams.n_vocab;
 
-    uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+    uint32_t n_ff = get_n_ff(&hparams);
 
     struct ggml_context * ctx = model->ctx;
 
@@ -432,13 +436,13 @@ void randomize_model_lora(struct llama_model_lora * model, int seed, float mean,
     }
 }
-bool init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model) {
+bool init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model, int n_batch) {
     const auto & hparams = model->hparams;
 
     const int n_ctx   = hparams.n_ctx;
     const int n_embd  = hparams.n_embd;
     const int n_layer = hparams.n_layer;
 
-    const int64_t n_mem      = n_layer*n_ctx;
+    const int64_t n_mem      = n_layer*n_ctx*n_batch;
     const int64_t n_elements = n_embd*n_mem;
 
     // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
@@ -467,13 +471,13 @@ bool init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model) {
     return true;
 }
-bool init_kv_cache_lora(struct llama_kv_cache* cache, struct llama_model_lora * model) {
+bool init_kv_cache_lora(struct llama_kv_cache* cache, struct llama_model_lora * model, int n_batch) {
     const auto & hparams = model->hparams;
 
     const int n_ctx   = hparams.n_ctx;
     const int n_embd  = hparams.n_embd;
     const int n_layer = hparams.n_layer;
 
-    const int64_t n_mem      = n_layer*n_ctx;
+    const int64_t n_mem      = n_layer*n_ctx*n_batch;
    const int64_t n_elements = n_embd*n_mem;
 
     // cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
@@ -727,6 +731,323 @@ struct ggml_tensor * forward(
     return inpL;
 }
 
+void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
+    GGML_ASSERT(tensor->n_dims == 1);
+    GGML_ASSERT(tensor->ne[0] == ne0);
+}
+
+void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
+    GGML_ASSERT(tensor->n_dims == 2);
+    GGML_ASSERT(tensor->ne[0] == ne0);
+    GGML_ASSERT(tensor->ne[1] == ne1);
+}
+
+void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
+    GGML_ASSERT(tensor->n_dims == 3);
+    GGML_ASSERT(tensor->ne[0] == ne0);
+    GGML_ASSERT(tensor->ne[1] == ne1);
+    GGML_ASSERT(tensor->ne[2] == ne2);
+}
+
+void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
+    GGML_ASSERT(tensor->n_dims == 4);
+    GGML_ASSERT(tensor->ne[0] == ne0);
+    GGML_ASSERT(tensor->ne[1] == ne1);
+    GGML_ASSERT(tensor->ne[2] == ne2);
+    GGML_ASSERT(tensor->ne[3] == ne3);
+}
+
+struct ggml_tensor * forward_batch(
+        struct llama_model * model,
+        struct llama_kv_cache * cache,
+        struct ggml_context * ctx0,
+        struct ggml_cgraph * gf,
+        struct ggml_tensor * tokens_input,
+        const int n_tokens,
+        const int n_past,
+        const int n_batch) {
+
+    const int N = n_tokens;
+
+    struct llama_kv_cache& kv_self = *cache;
+    const auto & hparams = model->hparams;
+    const int n_ctx   = hparams.n_ctx;
+    const int n_vocab = hparams.n_vocab;
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+    const int n_head  = hparams.n_head;
+    const int n_rot   = hparams.n_rot;
+    const int n_ff    = get_n_ff(&hparams);
+
+    struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
+    memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
+
+    struct ggml_tensor * kc = kv_self.k;
+    struct ggml_tensor * vc = kv_self.v;
+
+    // inpL shape [n_embd,N*n_batch,1]
+    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
+    assert_shape_2d(inpL, n_embd, N*n_batch);
+    for (int il = 0; il < n_layer; ++il) {
+        struct ggml_tensor * inpSA = inpL;
+
+        struct ggml_tensor * cur;
+
+        // lctx.use_buf(ctx0, 0);
+
+        // norm
+        {
+            // cur shape [n_embd,N*n_batch,1,1]
+            cur = ggml_rms_norm(ctx0, inpL);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+
+            // cur = attention_norm*cur
+            cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model->layers[il].attention_norm, cur),
+                        cur);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+        }
+
+        // self-attention
+        {
+            // compute Q and K and RoPE them
+            // wq   shape [n_embd, n_embd, 1, 1]
+            // wk   shape [n_embd, n_embd, 1, 1]
+            // Qcur shape [n_embd/n_head, n_head, N, n_batch]
+            // Kcur shape [n_embd/n_head, n_head, N, n_batch]
+            struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
+            assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
+            assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
+
+            // store key and value to memory
+            {
+                // compute the transposed [N, n_embd] V matrix
+                // wv   shape [n_embd, n_embd, 1, 1]
+                // Vcur shape [N, n_embd, n_batch, 1]
+                struct ggml_tensor * Vcur = ggml_cont(ctx0,
+                    ggml_permute(ctx0,
+                        ggml_reshape_3d(ctx0,
+                            ggml_mul_mat(ctx0,
+                                model->layers[il].wv,
+                                cur),
+                            n_embd, N, n_batch),
+                        1, 0, 2, 3));
+
+                assert_shape_3d(Vcur, N, n_embd, n_batch);
+
+                // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
+                // kv_self.v shape [n_ctx * n_embd * n_batch * n_layer]
+                // k         shape [n_embd * N, n_batch]   == kv_self.k[:,n_past:n_past+N,:,il]
+                // v         shape [N, n_embd, n_batch, 1] == kv_self.v[:,n_past:n_past+N,:,il]
+
+                /* {
+                    struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+                    struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
+                            ( n_ctx)*ggml_element_size(kv_self.v),
+                            (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+
+                    // important: storing RoPE-ed version of K in the KV cache!
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+                    ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
+                } //*/
+
+                kc = ggml_set_2d(ctx0, kc,
+                        ggml_reshape_2d(ctx0, Kcur, n_embd*N, n_batch),
+                        ggml_element_size(kc)*n_embd*n_ctx,
+                        (ggml_element_size(kc)*n_embd)*(il*n_batch*n_ctx + n_past));
+                vc = ggml_set_2d(ctx0, vc,
+                        ggml_reshape_2d(ctx0, Vcur, N*n_embd, n_batch),
+                        ggml_element_size(vc)*n_ctx*n_embd,
+                        ggml_element_size(vc)*(n_past + il*n_embd*n_batch*n_ctx));
+
+                assert_shape_1d(kc, n_embd * n_ctx * n_batch * n_layer);
+                assert_shape_1d(vc, n_embd * n_ctx * n_batch * n_layer);
+            }
+
+            // Qcur shape [n_embd/n_head, n_head, N, n_batch]
+            // Q shape    [n_embd/n_head, N, n_head, n_batch]
+            struct ggml_tensor * Q =
+                ggml_permute(ctx0,
+                        Qcur,
+                        0, 2, 1, 3);
+            assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch);
+
+            // kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
+            // K shape [n_embd/n_head, n_past + N, n_head, n_batch]
+            struct ggml_tensor * K =
+                ggml_permute(ctx0,
+                        ggml_reshape_4d(ctx0,
+                            ggml_view_3d(ctx0,
+                                kc,
+                                n_embd,
+                                (n_past + N),
+                                n_batch,
+                                n_embd*ggml_element_size(kc),
+                                n_ctx*n_embd*ggml_element_size(kc),
+                                il*n_batch*n_ctx*n_embd*ggml_element_size(kc)),
+                            n_embd/n_head, n_head, n_past + N, n_batch),
+                        0, 2, 1, 3);
+            assert_shape_4d(K, n_embd/n_head, n_past + N, n_head, n_batch);
+
+            // K * Q
+            // KQ shape [n_past + N, N, n_head, n_batch]
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            assert_shape_4d(KQ, n_past + N, N, n_head, n_batch);
+
+            // KQ_scaled = KQ / sqrt(n_embd/n_head)
+            // KQ_scaled shape [n_past + N, N, n_head, n_batch]
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctx0,
+                        KQ,
+                        ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
+            assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch);
+
+            // KQ_masked = mask_past(KQ_scaled)
+            // KQ_masked shape [n_past + N, N, n_head, n_batch]
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+            assert_shape_4d(KQ_masked, n_past + N, N, n_head, n_batch);
+
+            // KQ = soft_max(KQ_masked)
+            // KQ_soft_max shape [n_past + N, N, n_head, n_batch]
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+            assert_shape_4d(KQ_soft_max, n_past + N, N, n_head, n_batch);
+
+            // split cached V into n_head heads
+            // kv_self.v shape [n_ctx * n_embd * n_batch * n_layer]
+            // V shape [n_past + N, n_embd/n_head, n_head, n_batch] == kv_self.v[:(n_past+N),:,:,il]
+            struct ggml_tensor * V =
+                ggml_view_4d(ctx0, vc,
+                        n_past + N, n_embd/n_head, n_head, n_batch,
+                        ggml_element_size(vc)*n_ctx,
+                        ggml_element_size(vc)*n_ctx*n_embd/n_head,
+                        ggml_element_size(vc)*n_ctx*n_embd,
+                        il*n_batch*n_ctx*n_embd*ggml_element_size(vc));
+            assert_shape_4d(V, n_past + N, n_embd/n_head, n_head, n_batch);
+
+            // KQV shape [n_embd/n_head, N, n_head, n_batch]
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+            assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);
+
+            // KQV_merged = KQV.permute(0, 2, 1, 3)
+            // KQV_merged shape [n_embd/n_head, n_head, N, n_batch]
+            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch);
+            // KQV_merged shape
+
+            // cur = KQV_merged.contiguous().view(n_embd, N)
+            // cur shape [n_embd,N*n_batch,1,1]
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+            // cur = ggml_cpy(ctx0,
+            //         KQV_merged,
+            //         ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+
+            // projection (no bias)
+            // cur shape [n_embd,N*n_batch,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].wo,
+                    cur);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+        }
+
+        // lctx.use_buf(ctx0, 1);
+
+        // inpFF shape [n_embd,N*n_batch,1,1]
+        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+        assert_shape_2d(inpFF, n_embd, N*n_batch);
+
+        // feed-forward network
+        {
+            // norm
+            {
+                // cur shape [n_embd,N*n_batch,1,1]
+                cur = ggml_rms_norm(ctx0, inpFF);
+                assert_shape_2d(cur, n_embd, N*n_batch);
+
+                // cur = ffn_norm*cur
+                // cur shape [n_embd,N*n_batch,1,1]
+                cur = ggml_mul(ctx0,
+                        ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
+                        cur);
+                assert_shape_2d(cur, n_embd, N*n_batch);
+            }
+
+            // tmp shape [n_ff,N*n_batch,1,1]
+            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
+                    model->layers[il].w3,
+                    cur);
+            assert_shape_2d(tmp, n_ff, N*n_batch);
+
+            // cur shape [n_ff,N*n_batch,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].w1,
+                    cur);
+            assert_shape_2d(cur, n_ff, N*n_batch);
+
+            // SILU activation
+            // cur shape [n_ff,N*n_batch,1,1]
+            cur = ggml_silu(ctx0, cur);
+            assert_shape_2d(cur, n_ff, N*n_batch);
+
+            // cur shape [n_ff,N*n_batch,1,1]
+            cur = ggml_mul(ctx0, cur, tmp);
+            assert_shape_2d(cur, n_ff, N*n_batch);
+
+            // cur shape [n_embd,N*n_batch,1,1]
+            cur = ggml_mul_mat(ctx0,
+                    model->layers[il].w2,
+                    cur);
+            assert_shape_2d(cur, n_embd, N*n_batch);
+        }
+
+        // cur shape [n_embd,N*n_batch,1,1]
+        cur = ggml_add(ctx0, cur, inpFF);
+        assert_shape_2d(cur, n_embd, N*n_batch);
+
+        // input for next layer
+        // inpL shape [n_embd,N*n_batch,1,1]
+        inpL = cur;
+        assert_shape_2d(inpL, n_embd, N*n_batch);
+    }
+
+    // norm
+    {
+
+        // inpL shape [n_embd,N*n_batch,1,1]
+        inpL = ggml_rms_norm(ctx0, inpL);
+        assert_shape_2d(inpL, n_embd, N*n_batch);
+
+        // inpL = norm*inpL
+        // inpL shape [n_embd,N*n_batch,1,1]
+        inpL = ggml_mul(ctx0,
+                    ggml_repeat(ctx0, model->norm, inpL),
+                    inpL);
+
+        assert_shape_2d(inpL, n_embd, N*n_batch);
+
+        //embeddings = inpL;
+    }
+
+    // lm_head
+    // inpL shape [n_vocab,N*n_batch,1,1]
+    inpL = ggml_mul_mat(ctx0, model->output, inpL);
+    assert_shape_2d(inpL, n_vocab, N*n_batch);
+
+    {
+        // inpL shape [n_vocab,N,n_batch,1]
+        inpL = ggml_reshape_3d(ctx0,
+                        inpL,
+                        n_vocab, N, n_batch);
+        assert_shape_3d(inpL, n_vocab, N, n_batch);
+    }
+
+    // run the computation
+    ggml_build_forward_expand(gf, inpL);
+
+    return inpL;
+}
+
 struct ggml_tensor * forward_lora(
         struct llama_model_lora * model,
@@ -1013,6 +1334,40 @@ void sample_softmax(struct ggml_tensor * logits, struct ggml_tensor * probs, str
     }
 }
 
+void sample_softmax_batch(struct ggml_context * ctx, struct ggml_tensor * logits, struct ggml_tensor * probs, struct ggml_tensor * best_samples) {
+    GGML_ASSERT(best_samples->n_dims == 2);
+    GGML_ASSERT(logits->n_dims == 3);
+    GGML_ASSERT(probs->n_dims == 3);
+    int n_tokens = best_samples->ne[0];
+    int n_batch  = best_samples->ne[1];
+    int n_vocab  = logits->ne[0];
+    GGML_ASSERT(n_tokens == logits->ne[1]);
+    GGML_ASSERT(n_batch  == logits->ne[2]);
+    GGML_ASSERT(n_vocab  == probs->ne[0]);
+    GGML_ASSERT(n_tokens == probs->ne[1]);
+    GGML_ASSERT(n_batch  == probs->ne[2]);
+
+    for (int k=0; k<n_batch; ++k) {
+        struct ggml_tensor * best_samples_k = ggml_view_1d(ctx,
+                                                best_samples,
+                                                best_samples->ne[0],
+                                                k*best_samples->nb[1]);
+        struct ggml_tensor * logits_k = ggml_view_2d(ctx,
+                                                logits,
+                                                logits->ne[0],
+                                                logits->ne[1],
+                                                logits->nb[1],
+                                                k*logits->nb[2]);
+        struct ggml_tensor * probs_k = ggml_view_2d(ctx,
+                                                probs,
+                                                probs->ne[0],
+                                                probs->ne[1],
+                                                probs->nb[1],
+                                                k*probs->nb[2]);
+        sample_softmax(logits_k, probs_k, best_samples_k);
+    }
+}
+
 void print_row(struct ggml_tensor * probs, int i) {
     for (int k = 0; k < probs->ne[0]; ++k) {
         float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k);
@@ -1071,6 +1426,30 @@ void get_example_targets(int example_id, struct ggml_tensor * tokens_input, stru
     }
 }
 
+void get_example_targets_batch(struct ggml_context * ctx, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
+    GGML_ASSERT(tokens_input->n_dims == 2);
+    GGML_ASSERT( targets->n_dims == 3);
+    int n_tokens = tokens_input->ne[0];
+    int n_batch  = tokens_input->ne[1];
+    int n_vocab  = targets->ne[0];
+    GGML_ASSERT(n_tokens == targets->ne[1]);
+    GGML_ASSERT(n_batch  == targets->ne[2]);
+
+    for (int k=0; k<n_batch; ++k) {
+        struct ggml_tensor * tokens_input_k = ggml_view_1d(ctx,
+                                                tokens_input,
+                                                tokens_input->ne[0],
+                                                k*tokens_input->nb[1]);
+        struct ggml_tensor * targets_k = ggml_view_2d(ctx,
+                                                targets,
+                                                targets->ne[0],
+                                                targets->ne[1],
+                                                targets->nb[1],
+                                                k*targets->nb[2]);
+        get_example_targets(example_id*n_batch + k, tokens_input_k, targets_k);
+    }
+}
+
 void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * targets, int n_shift) {
     int n_tokens = tokens_input->ne[0];
     int n_vocab = targets->ne[0];
@@ -1162,12 +1541,12 @@ int main(int argc, char ** argv) {
     randomize_model_lora(&model_lora, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
     */
 
-
+    int n_batch = 8;
     // key + value cache for the self attention
     struct llama_kv_cache kv_self;
     printf("init_kv_cache\n");
     kv_self.ctx = model.ctx;
-    init_kv_cache(&kv_self, &model);
+    init_kv_cache(&kv_self, &model, n_batch);
     //init_kv_cache_lora(&kv_self, &model_lora);
 
     size_t compute_size = 1024ll*1024ll*1024ll;
@@ -1187,16 +1566,16 @@ int main(int argc, char ** argv) {
 
         struct ggml_context * ctx0 = ggml_init(params);
 
-        struct ggml_tensor * before_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        struct ggml_tensor * before_opt_probs        = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
-        struct ggml_tensor * after_opt_best_samples  = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        struct ggml_tensor * after_opt_probs         = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
-        struct ggml_tensor * tokens_input1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        struct ggml_tensor * tokens_input2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * before_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
+        struct ggml_tensor * before_opt_probs        = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
+        struct ggml_tensor * after_opt_best_samples  = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
+        struct ggml_tensor * after_opt_probs         = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
+        struct ggml_tensor * tokens_input1 = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
+        struct ggml_tensor * tokens_input2 = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
         // struct ggml_tensor * tokens_input3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         // struct ggml_tensor * tokens_input4 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        struct ggml_tensor * targets1 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
-        struct ggml_tensor * targets2 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
+        struct ggml_tensor * targets1 = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
+        struct ggml_tensor * targets2 = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
         // struct ggml_tensor * targets3 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
         // struct ggml_tensor * targets4 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
@@ -1205,24 +1584,24 @@ int main(int argc, char ** argv) {
         ggml_cgraph gf = {};
         gf.n_threads = 1;
 
-        get_example_targets(64*ex+0, tokens_input1, targets1);
-        get_example_targets(64*ex+16, tokens_input2, targets2);
+        get_example_targets_batch(ctx0, 64*ex+0, tokens_input1, targets1);
+        // get_example_targets_batch(64*ex+16, tokens_input2, targets2);
         // get_example_targets(64*ex+32, tokens_input3, targets3);
         // get_example_targets(64*ex+48, tokens_input4, targets4);
         // print_matrix(targets);
         // print_tokens(tokens_input, n_vocab);
 
-        struct ggml_tensor * logits1 = forward(&model, &kv_self, ctx0, &gf, tokens_input1, n_tokens, n_past);
-        struct ggml_tensor * logits2 = forward(&model, &kv_self, ctx0, &gf, tokens_input2, n_tokens, n_past);
+        struct ggml_tensor * logits1 = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input1, n_tokens, n_past, n_batch);
+        // struct ggml_tensor * logits2 = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input2, n_tokens, n_past, n_batch);
         // struct ggml_tensor * logits3 = forward(&model, &kv_self, ctx0, &gf, tokens_input3, n_tokens, n_past);
         // struct ggml_tensor * logits4 = forward(&model, &kv_self, ctx0, &gf, tokens_input4, n_tokens, n_past);
 
         // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets1, logits1);
-        // struct ggml_tensor * e = square_error_loss(ctx0, targets1, logits1);
+        struct ggml_tensor * e = square_error_loss(ctx0, targets1, logits1);
 
-        struct ggml_tensor * e = ggml_add(ctx0,
-                        square_error_loss(ctx0, targets1, logits1),
-                        square_error_loss(ctx0, targets2, logits2));
+        // struct ggml_tensor * e = ggml_add(ctx0,
+        //                 square_error_loss(ctx0, targets1, logits1),
+        //                 square_error_loss(ctx0, targets2, logits2));
         // struct ggml_tensor * e = ggml_add(ctx0,
         //                 cross_entropy_loss(ctx0, targets1, logits1),
         //                 cross_entropy_loss(ctx0, targets2, logits2));
@@ -1269,7 +1648,7 @@ int main(int argc, char ** argv) {
         }
 
         if (ex % 64 == 0) {
-            sample_softmax(logits1, after_opt_probs, after_opt_best_samples);
+            sample_softmax_batch(ctx0, logits1, after_opt_probs, after_opt_best_samples);
            // printf("probabilities after optimization:\n");
             // print_matrix(after_opt_probs);
             printf("best samples after optimization:\n");
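
Taken together, the main() hunks above give the call pattern for the new batched path: 2-d token tensors of shape [n_tokens, n_batch], 3-d targets/probs of shape [n_vocab, n_tokens, n_batch], one forward_batch call per iteration instead of one forward call per example. The following condensed sketch (not part of the patch) restates that pattern in one place; it assumes a loaded model, a kv_self cache created with the new init_kv_cache(&kv_self, &model, n_batch) signature, and a compute context ctx0, and the tensor names here are placeholders mirroring the *1 and after_opt_* tensors in the diff. Deriving n_tokens and n_vocab from model.hparams is an assumption about how the example sets them up; the optimizer step is omitted.

    int n_batch  = 8;
    int n_tokens = model.hparams.n_ctx;    // assumption: full-context training examples
    int n_vocab  = model.hparams.n_vocab;

    // one column per batch entry, matching the shapes allocated in main()
    struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
    struct ggml_tensor * targets      = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
    struct ggml_tensor * probs        = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
    struct ggml_tensor * best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);

    // fill all n_batch training examples in one call
    get_example_targets_batch(ctx0, /*example_id=*/ 0, tokens_input, targets);

    // one graph evaluation produces logits for the whole batch: [n_vocab, n_tokens, n_batch]
    struct ggml_cgraph gf = {};
    gf.n_threads = 1;
    struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, /*n_past=*/ 0, n_batch);
    ggml_graph_compute(ctx0, &gf);

    // per-example argmax sampling over the batched logits
    sample_softmax_batch(ctx0, logits, probs, best_samples);

A single forward_batch call thus replaces the two separate forward calls the old loop issued, with the batch dimension carried through the per-layer KV cache views that forward_batch sets up via ggml_set_2d.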