optimize loss over multiple samples

This increases the size of the computation graph; a parallel batched forward pass is needed for better efficiency.
xaedes 2023-05-07 01:23:51 +02:00
parent 7a5dec24f8
commit 226521a4f1
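The diff below hard-codes two samples (three and four are sketched in comments). As a rough sketch of the same idea, not part of the commit, and assuming the helpers that already appear in the diff (get_example_targets, forward, square_error_loss) plus hypothetical per-sample tensor arrays tokens_input[] and targets[], the per-sample losses could be accumulated into a single loss node in a loop:

    // Sketch only: accumulate per-sample square-error losses into one loss node.
    // tokens_input[] and targets[] are assumed arrays of per-sample tensors,
    // allocated like the tokens_input1/targets1 pairs in the diff.
    struct ggml_tensor * loss = NULL;
    for (int i = 0; i < n_batch; ++i) {
        get_example_targets(64*ex + 16*i, tokens_input[i], targets[i]);
        struct ggml_tensor * logits_i =
            forward(&model, &kv_self, ctx0, &gf, tokens_input[i], n_tokens, n_past);
        struct ggml_tensor * e_i = square_error_loss(ctx0, targets[i], logits_i);
        loss = (loss == NULL) ? e_i : ggml_add(ctx0, loss, e_i);
    }
    ggml_build_forward_expand(&gf, loss);

Each forward() call still adds its own subgraph, so the graph grows linearly with the number of samples, which is what the commit message's note about a parallel batched forward pass refers to.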


@@ -681,27 +681,54 @@ int main(int argc, char ** argv) {
         // struct ggml_tensor * before_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
         struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
         struct ggml_tensor * after_opt_probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
-        struct ggml_tensor * tokens_input = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        struct ggml_tensor * targets = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
+        struct ggml_tensor * tokens_input1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * tokens_input2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        // struct ggml_tensor * tokens_input3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        // struct ggml_tensor * tokens_input4 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+        struct ggml_tensor * targets1 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
+        struct ggml_tensor * targets2 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
+        // struct ggml_tensor * targets3 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
+        // struct ggml_tensor * targets4 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_vocab, n_tokens);
         int n_past = 0;
         ggml_cgraph gf = {};
         gf.n_threads = 1;
-        get_example_targets(ex, tokens_input, targets);
+        printf("Example %d\n", (ex+1));
+        get_example_targets(64*ex+0, tokens_input1, targets1);
+        get_example_targets(64*ex+16, tokens_input2, targets2);
+        // get_example_targets(64*ex+32, tokens_input3, targets3);
+        // get_example_targets(64*ex+48, tokens_input4, targets4);
         // print_probs(targets);
         // print_tokens(tokens_input, model.hparams.n_vocab);
-        struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past);
-        struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
+        struct ggml_tensor * logits1 = forward(&model, &kv_self, ctx0, &gf, tokens_input1, n_tokens, n_past);
+        struct ggml_tensor * logits2 = forward(&model, &kv_self, ctx0, &gf, tokens_input2, n_tokens, n_past);
+        // struct ggml_tensor * logits3 = forward(&model, &kv_self, ctx0, &gf, tokens_input3, n_tokens, n_past);
+        // struct ggml_tensor * logits4 = forward(&model, &kv_self, ctx0, &gf, tokens_input4, n_tokens, n_past);
+        // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets1, logits1);
+        // struct ggml_tensor * e = square_error_loss(ctx0, targets1, logits1);
+        struct ggml_tensor * e = ggml_add(ctx0,
+            square_error_loss(ctx0, targets1, logits1),
+            square_error_loss(ctx0, targets2, logits2));
+        // struct ggml_tensor * e = ggml_add(ctx0,
+        //     cross_entropy_loss(ctx0, targets1, logits1),
+        //     cross_entropy_loss(ctx0, targets2, logits2));
+        // struct ggml_tensor * e = ggml_add(ctx0,
+        //     ggml_add(ctx0,
+        //         cross_entropy_loss(ctx0, targets1, logits1),
+        //         cross_entropy_loss(ctx0, targets2, logits2)),
+        //     ggml_add(ctx0,
+        //         cross_entropy_loss(ctx0, targets3, logits3),
+        //         cross_entropy_loss(ctx0, targets4, logits4)));
         ggml_build_forward_expand(&gf, e);
         ggml_graph_compute(ctx0, &gf);
         float error_before_opt = ggml_get_f32_1d(e, 0);
-        // sample_softmax(logits, before_opt_probs, before_opt_best_samples);
+        // sample_softmax(logits1, before_opt_probs, before_opt_best_samples);
         // printf("probabilities before optimization:\n");
         // print_probs(before_opt_probs);
@@ -732,7 +759,7 @@ int main(int argc, char ** argv) {
         }
         if (ex % 64 == 0) {
-            sample_softmax(logits, after_opt_probs, after_opt_best_samples);
+            sample_softmax(logits1, after_opt_probs, after_opt_best_samples);
             // printf("probabilities after optimization:\n");
             // print_probs(after_opt_probs);
             printf("best samples after optimization:\n");
@@ -804,6 +831,6 @@ int main(int argc, char ** argv) {
     printf("done\n");
     // ggml_free(kv_self.ctx);
-    // ggml_free(model.ctx);
+    ggml_free(model.ctx);
     return 0;
 }
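On the "parallel batched forward" note in the commit message: a shape-level sketch of what that could look like, where forward_batch and its signature are hypothetical (nothing in this commit provides them) and square_error_loss is assumed to reduce over all elements of its inputs:

    // Hypothetical sketch: one forward pass over n_batch samples.
    // tokens_batch is assumed to be an I32 tensor of shape [n_tokens, n_batch],
    // targets_batch an F32 tensor of shape [n_vocab, n_tokens, n_batch].
    // A batched forward would return logits of shape [n_vocab, n_tokens, n_batch],
    // so a single loss node covers all samples and the graph no longer grows per sample.
    struct ggml_tensor * logits_batch =
        forward_batch(&model, &kv_self, ctx0, &gf, tokens_batch, n_tokens, n_past, n_batch);
    struct ggml_tensor * e = square_error_loss(ctx0, targets_batch, logits_batch);
    ggml_build_forward_expand(&gf, e);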