cleanup code for batched training
parent 3e3ed9560c
commit 581e5eb954
1 changed file with 9 additions and 53 deletions
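The change collapses the per-example tensors (tokens_input1/2, targets1/2, logits1, and the commented 3/4 variants) into single batched tensors that cover the whole batch at once. Roughly, the training step after this commit looks like the condensed sketch below, stitched together from the lines added in the hunks that follow; the loop counter ex, the ggml_init params, model, kv_self, n_tokens/n_batch/n_vocab, and the helpers get_example_targets_batch, forward_batch and square_error_loss are defined elsewhere in the example and are assumed here:

    // one training step over a whole batch (sketch, not verbatim)
    struct ggml_context * ctx0 = ggml_init(params);

    // token ids [n_tokens, n_batch] and target probabilities [n_vocab, n_tokens, n_batch]
    struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
    struct ggml_tensor * targets      = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);

    int n_past = 0;
    ggml_cgraph gf = {};
    gf.n_threads = 1;

    // fill the whole batch in one call instead of per-example tensors
    get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets);

    // single batched forward pass and a single loss tensor
    struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
    struct ggml_tensor * e      = square_error_loss(ctx0, targets, logits);

    ggml_build_forward_expand(&gf, e);
    ggml_graph_compute(ctx0, &gf);
    float error_before_opt = ggml_get_f32_1d(e, 0);

    // optimize, then recompute the graph to read the post-optimization loss
    ggml_opt(ctx0, opt_params_lbfgs, e);
    ggml_build_forward_expand(&gf, e);
    ggml_graph_compute(ctx0, &gf);
    float error_after_opt = ggml_get_f32_1d(e, 0);

    ggml_free(ctx0);

In the full example the two error values are printed every 8 examples and a batch of samples is decoded every 64 examples, as shown in the hunks below.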
@@ -1566,63 +1566,26 @@ int main(int argc, char ** argv) {

         struct ggml_context * ctx0 = ggml_init(params);

-        struct ggml_tensor * before_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
-        struct ggml_tensor * before_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
         struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
         struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
-        struct ggml_tensor * tokens_input1 = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
-        struct ggml_tensor * tokens_input2 = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
-        // struct ggml_tensor * tokens_input3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        // struct ggml_tensor * tokens_input4 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
-        struct ggml_tensor * targets1 = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
-        struct ggml_tensor * targets2 = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
-        // struct ggml_tensor * targets3 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
-        // struct ggml_tensor * targets4 = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens);
+        struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
+        struct ggml_tensor * targets = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);

         int n_past = 0;

         ggml_cgraph gf = {};
         gf.n_threads = 1;

-        get_example_targets_batch(ctx0, 64*ex+0, tokens_input1, targets1);
-        // get_example_targets_batch(64*ex+16, tokens_input2, targets2);
-        // get_example_targets(64*ex+32, tokens_input3, targets3);
-        // get_example_targets(64*ex+48, tokens_input4, targets4);
-        // print_matrix(targets);
-        // print_tokens(tokens_input, n_vocab);
+        get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets);

-        struct ggml_tensor * logits1 = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input1, n_tokens, n_past, n_batch);
-        // struct ggml_tensor * logits2 = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input2, n_tokens, n_past, n_batch);
-        // struct ggml_tensor * logits3 = forward(&model, &kv_self, ctx0, &gf, tokens_input3, n_tokens, n_past);
-        // struct ggml_tensor * logits4 = forward(&model, &kv_self, ctx0, &gf, tokens_input4, n_tokens, n_past);
-
-        // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets1, logits1);
-        struct ggml_tensor * e = square_error_loss(ctx0, targets1, logits1);
-
-        // struct ggml_tensor * e = ggml_add(ctx0,
-        //     square_error_loss(ctx0, targets1, logits1),
-        //     square_error_loss(ctx0, targets2, logits2));
-        // struct ggml_tensor * e = ggml_add(ctx0,
-        //     cross_entropy_loss(ctx0, targets1, logits1),
-        //     cross_entropy_loss(ctx0, targets2, logits2));
-        // struct ggml_tensor * e = ggml_add(ctx0,
-        //     ggml_add(ctx0,
-        //         cross_entropy_loss(ctx0, targets1, logits1),
-        //         cross_entropy_loss(ctx0, targets2, logits2)),
-        //     ggml_add(ctx0,
-        //         cross_entropy_loss(ctx0, targets3, logits3),
-        //         cross_entropy_loss(ctx0, targets4, logits4)));
+        struct ggml_tensor * logits = forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
+        // struct ggml_tensor * e = cross_entropy_loss(ctx0, targets, logits);
+        struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);

         ggml_build_forward_expand(&gf, e);
         ggml_graph_compute(ctx0, &gf);

         float error_before_opt = ggml_get_f32_1d(e, 0);
-        // sample_softmax(logits1, before_opt_probs, before_opt_best_samples);
-
-        // printf("probabilities before optimization:\n");
-        // print_matrix(before_opt_probs);
-        // printf("best samples before optimization:\n");
-        // print_tokens(before_opt_best_samples, n_vocab);

         struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
         struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
@@ -1632,15 +1595,14 @@ int main(int argc, char ** argv) {
         opt_params_lbfgs.print_backward_graph = false;
         opt_params_adam.adam.n_iter = 16;
         opt_params_lbfgs.lbfgs.n_iter = 16;
-        ggml_opt(ctx0, opt_params_adam, e);
-        // ggml_opt(ctx0, opt_params_lbfgs, e);
+        // ggml_opt(ctx0, opt_params_adam, e);
+        ggml_opt(ctx0, opt_params_lbfgs, e);
         //
         ggml_build_forward_expand(&gf, e);
         ggml_graph_compute(ctx0, &gf);

         float error_after_opt = ggml_get_f32_1d(e, 0);

         if (ex % 8 == 0) {
             printf("Example %d\n", (ex+1));
             printf("error_before_opt: %.2f\n", error_before_opt);
@@ -1648,7 +1610,7 @@ int main(int argc, char ** argv) {
         }

         if (ex % 64 == 0) {
-            sample_softmax_batch(ctx0, logits1, after_opt_probs, after_opt_best_samples);
+            sample_softmax_batch(ctx0, logits, after_opt_probs, after_opt_best_samples);
             // printf("probabilities after optimization:\n");
             // print_matrix(after_opt_probs);
             printf("best samples after optimization:\n");
@@ -1708,12 +1670,6 @@ int main(int argc, char ** argv) {
         ggml_set_i32_1d(tokens_input, 0, 0);
         ggml_set_i32_1d(tokens_input, sample_ctx-1, token);

-        // printf("---\n");
-        // for (int i=0; i<sample_ctx-1; ++i) {
-        //     print_token(ggml_get_i32_1d(tokens_input, i), model.hparams.n_vocab);
-        // }
-        // printf("--\n");
-
         ggml_free(ctx0);
     }
 }