add parallel batched forward function for baby-llama training

This commit is contained in:
xaedes 2023-05-11 19:31:46 +02:00
parent 6ca682b19d
commit 3e3ed9560c
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -133,6 +133,10 @@ struct llama_hparams {
}
};
uint32_t get_n_ff(const struct llama_hparams* hparams) {
uint32_t n_ff = ((2*(4*hparams->n_embd)/3 + hparams->n_mult - 1)/hparams->n_mult)*hparams->n_mult;
return n_ff;
}
struct llama_hparams_lora {
uint32_t n_vocab = 32000;
@ -237,7 +241,7 @@ void init_model(struct llama_model * model) {
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_vocab = hparams.n_vocab;
uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
uint32_t n_ff = get_n_ff(&hparams);
struct ggml_context * ctx = model->ctx;
@ -432,13 +436,13 @@ void randomize_model_lora(struct llama_model_lora * model, int seed, float mean,
}
}
bool init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model) {
bool init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model, int n_batch) {
const auto & hparams = model->hparams;
const int n_ctx = hparams.n_ctx;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int64_t n_mem = n_layer*n_ctx;
const int64_t n_mem = n_layer*n_ctx*n_batch;
const int64_t n_elements = n_embd*n_mem;
// cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
@ -467,13 +471,13 @@ bool init_kv_cache(struct llama_kv_cache* cache, struct llama_model * model) {
return true;
}
bool init_kv_cache_lora(struct llama_kv_cache* cache, struct llama_model_lora * model) {
bool init_kv_cache_lora(struct llama_kv_cache* cache, struct llama_model_lora * model, int n_batch) {
const auto & hparams = model->hparams;
const int n_ctx = hparams.n_ctx;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int64_t n_mem = n_layer*n_ctx;
const int64_t n_mem = n_layer*n_ctx*n_batch;
const int64_t n_elements = n_embd*n_mem;
// cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
@ -727,6 +731,323 @@ struct ggml_tensor * forward(
return inpL;
}
void assert_shape_1d(struct ggml_tensor * tensor, int64_t ne0) {
GGML_ASSERT(tensor->n_dims == 1);
GGML_ASSERT(tensor->ne[0] == ne0);
}
void assert_shape_2d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1) {
GGML_ASSERT(tensor->n_dims == 2);
GGML_ASSERT(tensor->ne[0] == ne0);
GGML_ASSERT(tensor->ne[1] == ne1);
}
void assert_shape_3d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2) {
GGML_ASSERT(tensor->n_dims == 3);
GGML_ASSERT(tensor->ne[0] == ne0);
GGML_ASSERT(tensor->ne[1] == ne1);
GGML_ASSERT(tensor->ne[2] == ne2);
}
void assert_shape_4d(struct ggml_tensor * tensor, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
GGML_ASSERT(tensor->n_dims == 4);
GGML_ASSERT(tensor->ne[0] == ne0);
GGML_ASSERT(tensor->ne[1] == ne1);
GGML_ASSERT(tensor->ne[2] == ne2);
GGML_ASSERT(tensor->ne[3] == ne3);
}
struct ggml_tensor * forward_batch(
struct llama_model * model,
struct llama_kv_cache * cache,
struct ggml_context * ctx0,
struct ggml_cgraph * gf,
struct ggml_tensor * tokens_input,
const int n_tokens,
const int n_past,
const int n_batch) {
const int N = n_tokens;
struct llama_kv_cache& kv_self = *cache;
const auto & hparams = model->hparams;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_head = hparams.n_head;
const int n_rot = hparams.n_rot;
const int n_ff = get_n_ff(&hparams);
struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
struct ggml_tensor * kc = kv_self.k;
struct ggml_tensor * vc = kv_self.v;
// inpL shape [n_embd,N*n_batch,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
assert_shape_2d(inpL, n_embd, N*n_batch);
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
struct ggml_tensor * cur;
// lctx.use_buf(ctx0, 0);
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpL);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur
cur = ggml_mul(ctx0,
ggml_repeat(ctx0, model->layers[il].attention_norm, cur),
cur);
assert_shape_2d(cur, n_embd, N*n_batch);
}
// self-attention
{
// compute Q and K and RoPE them
// wq shape [n_embd, n_embd, 1, 1]
// wk shape [n_embd, n_embd, 1, 1]
// Qcur shape [n_embd/n_head, n_head, N, n_batch]
// Kcur shape [n_embd/n_head, n_head, N, n_batch]
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0);
assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
// wv shape [n_embd, n_embd, 1, 1]
// Vcur shape [N, n_embd, n_batch, 1]
struct ggml_tensor * Vcur = ggml_cont(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_mul_mat(ctx0,
model->layers[il].wv,
cur),
n_embd, N, n_batch),
1, 0, 2, 3));
assert_shape_3d(Vcur, N, n_embd, n_batch);
// kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
// kv_self.v shape [n_ctx * n_embd * n_batch * n_layer]
// k shape [n_embd * N, n_batch] == kv_self.k[:,n_past:n_past+N,:,il]
// v shape [N, n_embd, n_batch, 1] == kv_self.v[:,n_past:n_past+N,:,il]
/* {
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
// important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
} //*/
kc = ggml_set_2d(ctx0, kc,
ggml_reshape_2d(ctx0, Kcur, n_embd*N, n_batch),
ggml_element_size(kc)*n_embd*n_ctx,
(ggml_element_size(kc)*n_embd)*(il*n_batch*n_ctx + n_past));
vc = ggml_set_2d(ctx0, vc,
ggml_reshape_2d(ctx0, Vcur, N*n_embd, n_batch),
ggml_element_size(vc)*n_ctx*n_embd,
ggml_element_size(vc)*(n_past + il*n_embd*n_batch*n_ctx));
assert_shape_1d(kc, n_embd * n_ctx * n_batch * n_layer);
assert_shape_1d(vc, n_embd * n_ctx * n_batch * n_layer);
}
// Qcur shape [n_embd/n_head, n_head, N, n_batch]
// Q shape [n_embd/n_head, N, n_head, n_batch]
struct ggml_tensor * Q =
ggml_permute(ctx0,
Qcur,
0, 2, 1, 3);
assert_shape_4d(Q, n_embd/n_head, N, n_head, n_batch);
// kv_self.k shape [n_embd * n_ctx * n_batch * n_layer]
// K shape [n_embd/n_head, n_past + N, n_head, n_batch]
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_reshape_4d(ctx0,
ggml_view_3d(ctx0,
kc,
n_embd,
(n_past + N),
n_batch,
n_embd*ggml_element_size(kc),
n_ctx*n_embd*ggml_element_size(kc),
il*n_batch*n_ctx*n_embd*ggml_element_size(kc)),
n_embd/n_head, n_head, n_past + N, n_batch),
0, 2, 1, 3);
assert_shape_4d(K, n_embd/n_head, n_past + N, n_head, n_batch);
// K * Q
// KQ shape [n_past + N, N, n_head, n_batch]
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
assert_shape_4d(KQ, n_past + N, N, n_head, n_batch);
// KQ_scaled = KQ / sqrt(n_embd/n_head)
// KQ_scaled shape [n_past + N, N, n_head, n_batch]
struct ggml_tensor * KQ_scaled =
ggml_scale(ctx0,
KQ,
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch);
// KQ_masked = mask_past(KQ_scaled)
// KQ_masked shape [n_past + N, N, n_head, n_batch]
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
assert_shape_4d(KQ_masked, n_past + N, N, n_head, n_batch);
// KQ = soft_max(KQ_masked)
// KQ_soft_max shape [n_past + N, N, n_head, n_batch]
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
assert_shape_4d(KQ_soft_max, n_past + N, N, n_head, n_batch);
// split cached V into n_head heads
// kv_self.v shape [n_ctx * n_embd * n_batch * n_layer]
// V shape [n_past + N, n_embd/n_head, n_head, n_batch] == kv_self.v[:(n_past+N),:,:,il]
struct ggml_tensor * V =
ggml_view_4d(ctx0, vc,
n_past + N, n_embd/n_head, n_head, n_batch,
ggml_element_size(vc)*n_ctx,
ggml_element_size(vc)*n_ctx*n_embd/n_head,
ggml_element_size(vc)*n_ctx*n_embd,
il*n_batch*n_ctx*n_embd*ggml_element_size(vc));
assert_shape_4d(V, n_past + N, n_embd/n_head, n_head, n_batch);
// KQV shape [n_embd/n_head, N, n_head, n_batch]
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
assert_shape_4d(KQV, n_embd/n_head, N, n_head, n_batch);
// KQV_merged = KQV.permute(0, 2, 1, 3)
// KQV_merged shape [n_embd/n_head, n_head, N, n_batch]
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
assert_shape_4d(KQV_merged, n_embd/n_head, n_head, N, n_batch);
// KQV_merged shape
// cur = KQV_merged.contiguous().view(n_embd, N)
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N*n_batch);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ggml_cpy(ctx0,
// KQV_merged,
// ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
// projection (no bias)
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_mul_mat(ctx0,
model->layers[il].wo,
cur);
assert_shape_2d(cur, n_embd, N*n_batch);
}
// lctx.use_buf(ctx0, 1);
// inpFF shape [n_embd,N*n_batch,1,1]
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
assert_shape_2d(inpFF, n_embd, N*n_batch);
// feed-forward network
{
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpFF);
assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_mul(ctx0,
ggml_repeat(ctx0, model->layers[il].ffn_norm, cur),
cur);
assert_shape_2d(cur, n_embd, N*n_batch);
}
// tmp shape [n_ff,N*n_batch,1,1]
struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
model->layers[il].w3,
cur);
assert_shape_2d(tmp, n_ff, N*n_batch);
// cur shape [n_ff,N*n_batch,1,1]
cur = ggml_mul_mat(ctx0,
model->layers[il].w1,
cur);
assert_shape_2d(cur, n_ff, N*n_batch);
// SILU activation
// cur shape [n_ff,N*n_batch,1,1]
cur = ggml_silu(ctx0, cur);
assert_shape_2d(cur, n_ff, N*n_batch);
// cur shape [n_ff,N*n_batch,1,1]
cur = ggml_mul(ctx0, cur, tmp);
assert_shape_2d(cur, n_ff, N*n_batch);
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_mul_mat(ctx0,
model->layers[il].w2,
cur);
assert_shape_2d(cur, n_embd, N*n_batch);
}
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_add(ctx0, cur, inpFF);
assert_shape_2d(cur, n_embd, N*n_batch);
// input for next layer
// inpL shape [n_embd,N*n_batch,1,1]
inpL = cur;
assert_shape_2d(inpL, n_embd, N*n_batch);
}
// norm
{
// inpL shape [n_embd,N*n_batch,1,1]
inpL = ggml_rms_norm(ctx0, inpL);
assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL
// inpL shape [n_embd,N*n_batch,1,1]
inpL = ggml_mul(ctx0,
ggml_repeat(ctx0, model->norm, inpL),
inpL);
assert_shape_2d(inpL, n_embd, N*n_batch);
//embeddings = inpL;
}
// lm_head
// inpL shape [n_vocab,N*n_batch,1,1]
inpL = ggml_mul_mat(ctx0, model->output, inpL);
assert_shape_2d(inpL, n_vocab, N*n_batch);
{
// inpL shape [n_vocab,N,n_batch,1]
inpL = ggml_reshape_3d(ctx0,
inpL,
n_vocab, N, n_batch);
assert_shape_3d(inpL, n_vocab, N, n_batch);
}
// run the computation
ggml_build_forward_expand(gf, inpL);
return inpL;
}
struct ggml_tensor * forward_lora(
struct llama_model_lora * model,
@ -1013,6 +1334,40 @@ void sample_softmax(struct ggml_tensor * logits, struct ggml_tensor * probs, str
}
}
void sample_softmax_batch(struct ggml_context * ctx, struct ggml_tensor * logits, struct ggml_tensor * probs, struct ggml_tensor * best_samples) {
GGML_ASSERT(best_samples->n_dims == 2);
GGML_ASSERT(logits->n_dims == 3);
GGML_ASSERT(probs->n_dims == 3);
int n_tokens = best_samples->ne[0];
int n_batch = best_samples->ne[1];
int n_vocab = logits->ne[0];
GGML_ASSERT(n_tokens == logits->ne[1]);
GGML_ASSERT(n_batch == logits->ne[2]);
GGML_ASSERT(n_vocab == probs->ne[0]);
GGML_ASSERT(n_tokens == probs->ne[1]);
GGML_ASSERT(n_batch == probs->ne[2]);
for (int k=0; k<n_batch; ++k) {
struct ggml_tensor * best_samples_k = ggml_view_1d(ctx,
best_samples,
best_samples->ne[0],
k*best_samples->nb[1]);
struct ggml_tensor * logits_k = ggml_view_2d(ctx,
logits,
logits->ne[0],
logits->ne[1],
logits->nb[1],
k*logits->nb[2]);
struct ggml_tensor * probs_k = ggml_view_2d(ctx,
probs,
probs->ne[0],
probs->ne[1],
probs->nb[1],
k*probs->nb[2]);
sample_softmax(logits_k, probs_k, best_samples_k);
}
}
void print_row(struct ggml_tensor * probs, int i) {
for (int k = 0; k < probs->ne[0]; ++k) {
float p = ggml_get_f32_1d(probs, i*probs->ne[0] + k);
@ -1071,6 +1426,30 @@ void get_example_targets(int example_id, struct ggml_tensor * tokens_input, stru
}
}
void get_example_targets_batch(struct ggml_context * ctx, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * targets) {
GGML_ASSERT(tokens_input->n_dims == 2);
GGML_ASSERT( targets->n_dims == 3);
int n_tokens = tokens_input->ne[0];
int n_batch = tokens_input->ne[1];
int n_vocab = targets->ne[0];
GGML_ASSERT(n_tokens == targets->ne[1]);
GGML_ASSERT(n_batch == targets->ne[2]);
for (int k=0; k<n_batch; ++k) {
struct ggml_tensor * tokens_input_k = ggml_view_1d(ctx,
tokens_input,
tokens_input->ne[0],
k*tokens_input->nb[1]);
struct ggml_tensor * targets_k = ggml_view_2d(ctx,
targets,
targets->ne[0],
targets->ne[1],
targets->nb[1],
k*targets->nb[2]);
get_example_targets(example_id*n_batch + k, tokens_input_k, targets_k);
}
}
void lshift_examples(struct ggml_tensor * tokens_input, struct ggml_tensor * targets, int n_shift) {
int n_tokens = tokens_input->ne[0];
int n_vocab = targets->ne[0];
@ -1162,12 +1541,12 @@ int main(int argc, char ** argv) {
randomize_model_lora(&model_lora, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
*/
int n_batch = 8;
// key + value cache for the self attention
struct llama_kv_cache kv_self;
printf("init_kv_cache\n");
kv_self.ctx = model.ctx;
init_kv_cache(&kv_self, &model);
init_kv_cache(&kv_self, &model, n_batch);
//init_kv_cache_lora(&kv_self, &model_lora);
size_t compute_size = 1024ll*1024ll*1024ll;
@ -1187,16 +1566,16 @@ int main(int argc, char ** argv) {
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_tensor * before_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
struct ggml_tensor * before_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
struct ggml_tensor * after_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
struct ggml_tensor * tokens_input1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
struct ggml_tensor * tokens_input2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
struct ggml_tensor * before_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
struct ggml_tensor * before_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
struct ggml_tensor * tokens_input1 = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
struct ggml_tensor * tokens_input2 = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
// struct ggml_tensor * tokens_input3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
// struct ggml_tensor * tokens_input4 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
struct ggml_tensor * targets1 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
struct ggml_tensor * targets2 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
struct ggml_tensor * targets1 = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
struct ggml_tensor * targets2 = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
// struct ggml_tensor * targets3 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
// struct ggml_tensor * targets4 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
@ -1205,24 +1584,24 @@ int main(int argc, char ** argv) {
ggml_cgraph gf = {};
gf.n_threads = 1;
get_example_targets(64*ex+0, tokens_input1, targets1);
get_example_targets(64*ex+16, tokens_input2, targets2);
get_example_targets_batch(ctx0, 64*ex+0, tokens_input1, targets1);
// get_example_targets_batch(64*ex+16, tokens_input2, targets2);
// get_example_targets(64*ex+32, tokens_input3, targets3);
// get_example_targets(64*ex+48, tokens_input4, targets4);
// print_matrix(targets);
// print_tokens(tokens_input, n_vocab);
struct ggml_tensor * logits1 = forward(&model, &kv_self, ctx0, &gf, tokens_input1, n_tokens, n_past);
struct ggml_tensor * logits2 = forward(&model, &kv_self, ctx0, &gf, tokens_input2, n_tokens, n_past);
struct ggml_tensor * logits1 = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input1, n_tokens, n_past, n_batch);
// struct ggml_tensor * logits2 = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input2, n_tokens, n_past, n_batch);
// struct ggml_tensor * logits3 = forward(&model, &kv_self, ctx0, &gf, tokens_input3, n_tokens, n_past);
// struct ggml_tensor * logits4 = forward(&model, &kv_self, ctx0, &gf, tokens_input4, n_tokens, n_past);
// struct ggml_tensor * e = cross_entropy_loss(ctx0, targets1, logits1);
// struct ggml_tensor * e = square_error_loss(ctx0, targets1, logits1);
struct ggml_tensor * e = square_error_loss(ctx0, targets1, logits1);
struct ggml_tensor * e = ggml_add(ctx0,
square_error_loss(ctx0, targets1, logits1),
square_error_loss(ctx0, targets2, logits2));
// struct ggml_tensor * e = ggml_add(ctx0,
// square_error_loss(ctx0, targets1, logits1),
// square_error_loss(ctx0, targets2, logits2));
// struct ggml_tensor * e = ggml_add(ctx0,
// cross_entropy_loss(ctx0, targets1, logits1),
// cross_entropy_loss(ctx0, targets2, logits2));
@ -1269,7 +1648,7 @@ int main(int argc, char ** argv) {
}
if (ex % 64 == 0) {
sample_softmax(logits1, after_opt_probs, after_opt_best_samples);
sample_softmax_batch(ctx0, logits1, after_opt_probs, after_opt_best_samples);
// printf("probabilities after optimization:\n");
// print_matrix(after_opt_probs);
printf("best samples after optimization:\n");