diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 9feafae98..c50baf470 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -566,11 +566,12 @@ int main(int argc, char ** argv) { lcparams.no_alloc = false; struct llama_model model; - model.hparams.n_vocab = 8; + model.hparams.n_vocab = 16; model.hparams.n_ctx = 64; model.hparams.n_embd = 64; + model.hparams.n_mult = 2; model.hparams.n_head = 8; - model.hparams.n_layer = 4; + model.hparams.n_layer = 16; model.hparams.n_rot = 16; model.ctx = ggml_init(lcparams); printf("init model\n"); @@ -605,18 +606,17 @@ int main(int argc, char ** argv) { float z = (y+1.0f)*0.5f; int token = (int)(z*(float)(model.hparams.n_vocab-1)); for (int k = 0; k < token; ++k) { - printf(" "); ggml_set_f32_1d(targets, i*model.hparams.n_vocab + k, 0.0f); } - printf("X"); ggml_set_f32_1d(targets, i*model.hparams.n_vocab + token, +1.0f); for (int k = token+1; k < model.hparams.n_vocab; ++k) { - printf(" "); ggml_set_f32_1d(targets, i*model.hparams.n_vocab + k, 0.0f); } - printf("\n"); ggml_set_i32_1d(tokens_input, i, token); } + print_probs(targets); + print_tokens(tokens_input, model.hparams.n_vocab); + int n_past = 0; ggml_cgraph gf = {}; @@ -637,8 +637,10 @@ int main(int argc, char ** argv) { printf("best samples before optimization:\n"); print_tokens(before_opt_best_samples, model.hparams.n_vocab); - struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM); - ggml_opt(ctx0, opt_params, e); + struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM); + struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS); + ggml_opt(ctx0, opt_params_lbfgs, e); + // ggml_opt(ctx0, opt_params_adam, e); // ggml_build_forward_expand(&gf, e); ggml_graph_compute(ctx0, &gf); diff --git a/ggml.h b/ggml.h index d14af5c59..832281566 100644 --- a/ggml.h +++ b/ggml.h @@ -192,7 +192,7 @@ #define GGML_MAX_DIMS 4 #define GGML_MAX_NODES 4096 -#define GGML_MAX_PARAMS 32 +#define GGML_MAX_PARAMS 256 #define GGML_MAX_CONTEXTS 64 #define GGML_MAX_OPT 4 #define GGML_DEFAULT_N_THREADS 4