diff --git a/examples/baby-llama/baby-llama-text.cpp b/examples/baby-llama/baby-llama-text.cpp
index aa8c3ace4..5c019f4bb 100644
--- a/examples/baby-llama/baby-llama-text.cpp
+++ b/examples/baby-llama/baby-llama-text.cpp
@@ -1433,7 +1433,7 @@ void save_model(struct my_llama_model * model, const char * filename) {
     }
 }
 
-void load_model(struct my_llama_model * model, const char * filename, bool init) {
+bool load_model(struct my_llama_model * model, const char * filename, bool init) {
     struct llama_file file(filename, "rb");
 
     if (file.fp) {
@@ -1474,24 +1474,28 @@ void load_model(struct my_llama_model * model, const char * filename, bool init)
             read_tensor(&file, layer.w3);
         }
     }
+
+    return (file.fp != NULL);
 }
 
 int main(int argc, char ** argv) {
     const char * default_model = "ggml-vic7b-uncensored-q4_0.bin";
     const char * default_train = "shakespeare.txt";
-    const char * default_checkpoint = "checkpoint.bin";
-    const char * default_argv[4] = {argv[0], default_model, default_train, default_checkpoint};
+    const char * default_chkpt_in  = "checkpoint.bin";
+    const char * default_chkpt_out = "checkpoint.bin";
+    const char * default_argv[5] = {argv[0], default_model, default_train, default_chkpt_in, default_chkpt_out};
 
-    if (argc < 4) {
-        fprintf(stderr, "usage: %s model training_data\n", argv[0]);
+    if (argc < 5) {
+        fprintf(stderr, "usage: %s model training_data chkpt_in chkpt_out\n", argv[0]);
         //return 1;
     }
 
     srand(time(NULL));
 
-    const char * fn_model = (argc >= 2) ? argv[1] : default_argv[1];
-    const char * fn_train = (argc >= 3) ? argv[2] : default_argv[2];
-    const char * fn_chkpt = (argc >= 4) ? argv[3] : default_argv[3];
+    const char * fn_model     = (argc >= 2) ? argv[1] : default_argv[1];
+    const char * fn_train     = (argc >= 3) ? argv[2] : default_argv[2];
+    const char * fn_chkpt_in  = (argc >= 4) ? argv[3] : default_argv[3];
+    const char * fn_chkpt_out = (argc >= 5) ? argv[4] : default_argv[4];
 
     struct llama_context_params llama_params = llama_context_default_params();
     llama_params.vocab_only = true;
@@ -1516,17 +1520,20 @@ int main(int argc, char ** argv) {
 
     print_params(&model.hparams);
 
-    std::vector<bool> token_occurs;
-    std::vector<bool> token_notavail;
-    token_occurs.resize(model.hparams.n_vocab, false);
+    std::vector<int>   token_noccurs;
+    std::vector<bool>  token_notavail;
+    token_noccurs.resize(model.hparams.n_vocab, 0);
     token_notavail.resize(model.hparams.n_vocab, true);
     for (int i=0; i<(int) train_tokens.size(); ++i) {
-        token_occurs[train_tokens[i]] = true;
+        ++token_noccurs[train_tokens[i]];
         token_notavail[train_tokens[i]] = false;
     }
+    std::vector<float> token_freq;
+    token_freq.resize(model.hparams.n_vocab, 0);
     int n_unique_tokens = 0;
-    for (int i=0; i<(int) token_occurs.size(); ++i) {
-        n_unique_tokens += token_occurs[i] ? 1 : 0;
+    for (int i=0; i<(int) token_noccurs.size(); ++i) {
+        token_freq[i] = (float) token_noccurs[i] / (float) train_tokens.size();
+        n_unique_tokens += (token_noccurs[i] > 0) ? 1 : 0;
     }
     printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
 
@@ -1545,9 +1552,12 @@ int main(int argc, char ** argv) {
     my_llama_sampler sampler;
 
     printf("%s: init model\n", __func__);
-    load_model(&model, fn_chkpt, true);
+    bool existed = load_model(&model, fn_chkpt_in, true);
+    bool from_scratch = !existed;
     set_param_model(&model);
-    randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
+    if (from_scratch) {
+        randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
+    }
 
     init_kv_cache(&kv_self, &model, n_batch);
     init_sampler(&sampler, lctx);
@@ -1559,10 +1569,12 @@ int main(int argc, char ** argv) {
     int n_tokens = model.hparams.n_ctx;
     int n_vocab  = model.hparams.n_vocab;
 
+    bool samples_start_after_nl = false;
+
     std::vector<int> train_samples;
     train_samples.push_back(0);
     for (int i=1; i
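
For reference, the third hunk replaces the boolean `token_occurs` bitmap with per-token occurrence counts (`token_noccurs`) and derives a `token_freq` table from them. Below is a minimal standalone sketch (not part of the patch) of that counting pass; `n_vocab` and `train_tokens` are made-up stand-ins for the values the example reads from the model hparams and the tokenized training file.

// sketch of the occurrence/frequency pass added by the patch, with stand-in data
#include <cstdio>
#include <vector>

int main() {
    const int n_vocab = 8;                                    // stand-in vocabulary size
    const std::vector<int> train_tokens = {1, 3, 3, 5, 1, 3}; // stand-in tokenized training data

    std::vector<int>  token_noccurs(n_vocab, 0);     // occurrence count per token id
    std::vector<bool> token_notavail(n_vocab, true); // true if token never occurs in the data
    for (int i = 0; i < (int) train_tokens.size(); ++i) {
        ++token_noccurs[train_tokens[i]];
        token_notavail[train_tokens[i]] = false;
    }

    std::vector<float> token_freq(n_vocab, 0.0f);    // relative frequency per token id
    int n_unique_tokens = 0;
    for (int i = 0; i < n_vocab; ++i) {
        token_freq[i] = (float) token_noccurs[i] / (float) train_tokens.size();
        n_unique_tokens += (token_noccurs[i] > 0) ? 1 : 0;
    }

    printf("number of unique tokens: %d\n", n_unique_tokens);
    return 0;
}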