switching from training with adam to lbfgs produces much better results in the baby-llama example

xaedes 2023-05-01 21:01:17 +02:00
parent 29a0f8b940
commit 5f23052eb2
2 changed files with 11 additions and 9 deletions


@@ -566,11 +566,12 @@ int main(int argc, char ** argv) {
     lcparams.no_alloc = false;
     struct llama_model model;
-    model.hparams.n_vocab = 8;
+    model.hparams.n_vocab = 16;
     model.hparams.n_ctx   = 64;
     model.hparams.n_embd  = 64;
     model.hparams.n_mult  = 2;
     model.hparams.n_head  = 8;
-    model.hparams.n_layer = 4;
+    model.hparams.n_layer = 16;
     model.hparams.n_rot   = 16;
     model.ctx = ggml_init(lcparams);
     printf("init model\n");
@@ -605,18 +606,17 @@ int main(int argc, char ** argv) {
         float z = (y+1.0f)*0.5f;
         int token = (int)(z*(float)(model.hparams.n_vocab-1));
         for (int k = 0; k < token; ++k) {
             printf(" ");
             ggml_set_f32_1d(targets, i*model.hparams.n_vocab + k, 0.0f);
         }
         printf("X");
         ggml_set_f32_1d(targets, i*model.hparams.n_vocab + token, +1.0f);
         for (int k = token+1; k < model.hparams.n_vocab; ++k) {
             printf(" ");
             ggml_set_f32_1d(targets, i*model.hparams.n_vocab + k, 0.0f);
         }
         printf("\n");
         ggml_set_i32_1d(tokens_input, i, token);
     }
     print_probs(targets);
     print_tokens(tokens_input, model.hparams.n_vocab);
     int n_past = 0;
     ggml_cgraph gf = {};
@@ -637,8 +637,10 @@ int main(int argc, char ** argv) {
     printf("best samples before optimization:\n");
     print_tokens(before_opt_best_samples, model.hparams.n_vocab);
-    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
-    ggml_opt(ctx0, opt_params, e);
+    struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
+    struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
+    ggml_opt(ctx0, opt_params_lbfgs, e);
+    // ggml_opt(ctx0, opt_params_adam, e);
     //
     ggml_build_forward_expand(&gf, e);
     ggml_graph_compute(ctx0, &gf);
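The optimizer switch itself is just a different ggml_opt_params struct handed to ggml_opt. Below is a minimal, self-contained sketch of that call pattern; it uses only public ggml API that the diff already relies on (ggml_init, ggml_set_param, ggml_opt_default_params, ggml_opt), and a toy scalar loss f(x) = x*x stands in for the example's real error tensor e, so it illustrates the mechanism rather than copying the example code.

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx0 = ggml_init(params);

    // toy problem: minimize f(x) = x*x, starting from x = 2
    struct ggml_tensor * x = ggml_new_f32(ctx0, 2.0f);
    ggml_set_param(ctx0, x);                       // mark x as trainable
    struct ggml_tensor * f = ggml_mul(ctx0, x, x); // scalar loss

    // before this commit: Adam
    // struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
    // ggml_opt(ctx0, opt_params_adam, f);

    // after this commit: L-BFGS, which the commit message reports as giving
    // much better results on the baby-llama task
    struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
    ggml_opt(ctx0, opt_params_lbfgs, f);

    printf("x after optimization: %f\n", ggml_get_f32_1d(x, 0));

    ggml_free(ctx0);
    return 0;
}

From the caller's point of view the only difference is the enum passed to ggml_opt_default_params; all Adam- and L-BFGS-specific settings live inside the returned ggml_opt_params struct.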

ggml.h

@@ -192,7 +192,7 @@
 #define GGML_MAX_DIMS 4
 #define GGML_MAX_NODES 4096
-#define GGML_MAX_PARAMS 32
+#define GGML_MAX_PARAMS 256
 #define GGML_MAX_CONTEXTS 64
 #define GGML_MAX_OPT 4
 #define GGML_DEFAULT_N_THREADS 4
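The ggml.h change is presumably what lets the deeper model train at all: every tensor flagged with ggml_set_param takes up one slot in a fixed, GGML_MAX_PARAMS-sized array when ggml_opt gathers the trainable tensors from the graph, and a 16-layer model has far more weight tensors than the old limit of 32. A hedged, self-contained illustration of that relationship, with 64 scalar "weights" standing in for the model's real weight matrices:

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 32*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // flag 64 scalar parameters; each ggml_set_param call needs one of the
    // GGML_MAX_PARAMS slots when ggml_opt later walks the graph, so the
    // trainable-tensor count must stay below the limit (fine at 256,
    // above the old limit of 32)
    struct ggml_tensor * loss = ggml_new_f32(ctx, 0.0f);
    for (int i = 0; i < 64; ++i) {
        struct ggml_tensor * w = ggml_new_f32(ctx, 1.0f);
        ggml_set_param(ctx, w);
        loss = ggml_add(ctx, loss, ggml_mul(ctx, w, w)); // sum of squares
    }

    // minimize with the same L-BFGS entry point the example now uses
    ggml_opt(ctx, ggml_opt_default_params(GGML_OPT_LBFGS), loss);

    ggml_free(ctx);
    return 0;
}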