use different arguments for input and output checkpoint

This commit is contained in:
xaedes 2023-05-19 18:34:18 +02:00
parent d8b0666429
commit 44d83558bc
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1433,7 +1433,7 @@ void save_model(struct my_llama_model * model, const char * filename) {
} }
} }
void load_model(struct my_llama_model * model, const char * filename, bool init) { bool load_model(struct my_llama_model * model, const char * filename, bool init) {
struct llama_file file(filename, "rb"); struct llama_file file(filename, "rb");
if (file.fp) { if (file.fp) {
@ -1474,16 +1474,19 @@ void load_model(struct my_llama_model * model, const char * filename, bool init)
read_tensor(&file, layer.w3); read_tensor(&file, layer.w3);
} }
} }
return (file.fp != NULL);
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
const char * default_model = "ggml-vic7b-uncensored-q4_0.bin"; const char * default_model = "ggml-vic7b-uncensored-q4_0.bin";
const char * default_train = "shakespeare.txt"; const char * default_train = "shakespeare.txt";
const char * default_checkpoint = "checkpoint.bin"; const char * default_chkpt_in = "checkpoint.bin";
const char * default_argv[4] = {argv[0], default_model, default_train, default_checkpoint}; const char * default_chkpt_out = "checkpoint.bin";
const char * default_argv[5] = {argv[0], default_model, default_train, default_chkpt_in, default_chkpt_out};
if (argc < 4) { if (argc < 5) {
fprintf(stderr, "usage: %s model training_data\n", argv[0]); fprintf(stderr, "usage: %s model training_data chkpt_in chkpt_out\n", argv[0]);
//return 1; //return 1;
} }
@ -1491,7 +1494,8 @@ int main(int argc, char ** argv) {
const char * fn_model = (argc >= 2) ? argv[1] : default_argv[1]; const char * fn_model = (argc >= 2) ? argv[1] : default_argv[1];
const char * fn_train = (argc >= 3) ? argv[2] : default_argv[2]; const char * fn_train = (argc >= 3) ? argv[2] : default_argv[2];
const char * fn_chkpt = (argc >= 4) ? argv[3] : default_argv[3]; const char * fn_chkpt_in = (argc >= 4) ? argv[3] : default_argv[3];
const char * fn_chkpt_out = (argc >= 5) ? argv[4] : default_argv[4];
struct llama_context_params llama_params = llama_context_default_params(); struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = true; llama_params.vocab_only = true;
@ -1516,17 +1520,20 @@ int main(int argc, char ** argv) {
print_params(&model.hparams); print_params(&model.hparams);
std::vector<bool> token_occurs; std::vector<size_t> token_noccurs;
std::vector<bool> token_notavail; std::vector<bool> token_notavail;
token_occurs.resize(model.hparams.n_vocab, false); token_noccurs.resize(model.hparams.n_vocab, 0);
token_notavail.resize(model.hparams.n_vocab, true); token_notavail.resize(model.hparams.n_vocab, true);
for (int i=0; i<train_tokens.size(); ++i) { for (int i=0; i<train_tokens.size(); ++i) {
token_occurs[train_tokens[i]] = true; ++token_noccurs[train_tokens[i]];
token_notavail[train_tokens[i]] = false; token_notavail[train_tokens[i]] = false;
} }
std::vector<float> token_freq;
token_freq.resize(model.hparams.n_vocab, 0);
int n_unique_tokens = 0; int n_unique_tokens = 0;
for (int i=0; i<token_occurs.size(); ++i) { for (int i=0; i<token_noccurs.size(); ++i) {
n_unique_tokens += token_occurs[i] ? 1 : 0; token_freq[i] = (float) token_noccurs[i] / (float) train_tokens.size();
n_unique_tokens += (token_noccurs[i] > 0) ? 1 : 0;
} }
printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens); printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
@ -1545,9 +1552,12 @@ int main(int argc, char ** argv) {
my_llama_sampler sampler; my_llama_sampler sampler;
printf("%s: init model\n", __func__); printf("%s: init model\n", __func__);
load_model(&model, fn_chkpt, true); bool existed = load_model(&model, fn_chkpt_in, true);
bool from_scratch = !existed;
set_param_model(&model); set_param_model(&model);
if (from_scratch) {
randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f); randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
}
init_kv_cache(&kv_self, &model, n_batch); init_kv_cache(&kv_self, &model, n_batch);
init_sampler(&sampler, lctx); init_sampler(&sampler, lctx);
@ -1559,10 +1569,12 @@ int main(int argc, char ** argv) {
int n_tokens = model.hparams.n_ctx; int n_tokens = model.hparams.n_ctx;
int n_vocab = model.hparams.n_vocab; int n_vocab = model.hparams.n_vocab;
bool samples_start_after_nl = false;
std::vector<int> train_samples; std::vector<int> train_samples;
train_samples.push_back(0); train_samples.push_back(0);
for (int i=1; i<train_tokens.size()-n_tokens; ++i) { for (int i=1; i<train_tokens.size()-n_tokens; ++i) {
if (train_tokens[i-1] == llama_token_nl()) { if (!samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) {
train_samples.push_back(i); train_samples.push_back(i);
} }
} }
@ -1674,7 +1686,7 @@ int main(int argc, char ** argv) {
ggml_free(ctx0); ggml_free(ctx0);
} }
save_model(&model, fn_chkpt); save_model(&model, fn_chkpt_out);
{ {
int n_gen = 128; int n_gen = 128;