switching from training with adam to lbfgs produces much better results in the baby-llama example
This commit is contained in:
parent
29a0f8b940
commit
5f23052eb2
2 changed files with 11 additions and 9 deletions
|
@ -566,11 +566,12 @@ int main(int argc, char ** argv) {
|
|||
lcparams.no_alloc = false;
|
||||
|
||||
struct llama_model model;
|
||||
model.hparams.n_vocab = 8;
|
||||
model.hparams.n_vocab = 16;
|
||||
model.hparams.n_ctx = 64;
|
||||
model.hparams.n_embd = 64;
|
||||
model.hparams.n_mult = 2;
|
||||
model.hparams.n_head = 8;
|
||||
model.hparams.n_layer = 4;
|
||||
model.hparams.n_layer = 16;
|
||||
model.hparams.n_rot = 16;
|
||||
model.ctx = ggml_init(lcparams);
|
||||
printf("init model\n");
|
||||
|
@ -605,18 +606,17 @@ int main(int argc, char ** argv) {
|
|||
float z = (y+1.0f)*0.5f;
|
||||
int token = (int)(z*(float)(model.hparams.n_vocab-1));
|
||||
for (int k = 0; k < token; ++k) {
|
||||
printf(" ");
|
||||
ggml_set_f32_1d(targets, i*model.hparams.n_vocab + k, 0.0f);
|
||||
}
|
||||
printf("X");
|
||||
ggml_set_f32_1d(targets, i*model.hparams.n_vocab + token, +1.0f);
|
||||
for (int k = token+1; k < model.hparams.n_vocab; ++k) {
|
||||
printf(" ");
|
||||
ggml_set_f32_1d(targets, i*model.hparams.n_vocab + k, 0.0f);
|
||||
}
|
||||
printf("\n");
|
||||
ggml_set_i32_1d(tokens_input, i, token);
|
||||
}
|
||||
print_probs(targets);
|
||||
print_tokens(tokens_input, model.hparams.n_vocab);
|
||||
|
||||
int n_past = 0;
|
||||
|
||||
ggml_cgraph gf = {};
|
||||
|
@ -637,8 +637,10 @@ int main(int argc, char ** argv) {
|
|||
printf("best samples before optimization:\n");
|
||||
print_tokens(before_opt_best_samples, model.hparams.n_vocab);
|
||||
|
||||
struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
|
||||
ggml_opt(ctx0, opt_params, e);
|
||||
struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
|
||||
struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
|
||||
ggml_opt(ctx0, opt_params_lbfgs, e);
|
||||
// ggml_opt(ctx0, opt_params_adam, e);
|
||||
//
|
||||
ggml_build_forward_expand(&gf, e);
|
||||
ggml_graph_compute(ctx0, &gf);
|
||||
|
|
2
ggml.h
2
ggml.h
|
@ -192,7 +192,7 @@
|
|||
|
||||
#define GGML_MAX_DIMS 4
|
||||
#define GGML_MAX_NODES 4096
|
||||
#define GGML_MAX_PARAMS 32
|
||||
#define GGML_MAX_PARAMS 256
|
||||
#define GGML_MAX_CONTEXTS 64
|
||||
#define GGML_MAX_OPT 4
|
||||
#define GGML_DEFAULT_N_THREADS 4
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue