use different arguments for input and output checkpoint

This commit is contained in:
xaedes 2023-05-19 18:34:18 +02:00
parent d8b0666429
commit 44d83558bc
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1433,7 +1433,7 @@ void save_model(struct my_llama_model * model, const char * filename) {
}
}
void load_model(struct my_llama_model * model, const char * filename, bool init) {
bool load_model(struct my_llama_model * model, const char * filename, bool init) {
struct llama_file file(filename, "rb");
if (file.fp) {
@ -1474,24 +1474,28 @@ void load_model(struct my_llama_model * model, const char * filename, bool init)
read_tensor(&file, layer.w3);
}
}
return (file.fp != NULL);
}
int main(int argc, char ** argv) {
const char * default_model = "ggml-vic7b-uncensored-q4_0.bin";
const char * default_train = "shakespeare.txt";
const char * default_checkpoint = "checkpoint.bin";
const char * default_argv[4] = {argv[0], default_model, default_train, default_checkpoint};
const char * default_chkpt_in = "checkpoint.bin";
const char * default_chkpt_out = "checkpoint.bin";
const char * default_argv[5] = {argv[0], default_model, default_train, default_chkpt_in, default_chkpt_out};
if (argc < 4) {
fprintf(stderr, "usage: %s model training_data\n", argv[0]);
if (argc < 5) {
fprintf(stderr, "usage: %s model training_data chkpt_in chkpt_out\n", argv[0]);
//return 1;
}
srand(time(NULL));
const char * fn_model = (argc >= 2) ? argv[1] : default_argv[1];
const char * fn_train = (argc >= 3) ? argv[2] : default_argv[2];
const char * fn_chkpt = (argc >= 4) ? argv[3] : default_argv[3];
const char * fn_model = (argc >= 2) ? argv[1] : default_argv[1];
const char * fn_train = (argc >= 3) ? argv[2] : default_argv[2];
const char * fn_chkpt_in = (argc >= 4) ? argv[3] : default_argv[3];
const char * fn_chkpt_out = (argc >= 5) ? argv[4] : default_argv[4];
struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = true;
@ -1516,17 +1520,20 @@ int main(int argc, char ** argv) {
print_params(&model.hparams);
std::vector<bool> token_occurs;
std::vector<bool> token_notavail;
token_occurs.resize(model.hparams.n_vocab, false);
std::vector<size_t> token_noccurs;
std::vector<bool> token_notavail;
token_noccurs.resize(model.hparams.n_vocab, 0);
token_notavail.resize(model.hparams.n_vocab, true);
for (int i=0; i<train_tokens.size(); ++i) {
token_occurs[train_tokens[i]] = true;
++token_noccurs[train_tokens[i]];
token_notavail[train_tokens[i]] = false;
}
std::vector<float> token_freq;
token_freq.resize(model.hparams.n_vocab, 0);
int n_unique_tokens = 0;
for (int i=0; i<token_occurs.size(); ++i) {
n_unique_tokens += token_occurs[i] ? 1 : 0;
for (int i=0; i<token_noccurs.size(); ++i) {
token_freq[i] = (float) token_noccurs[i] / (float) train_tokens.size();
n_unique_tokens += (token_noccurs[i] > 0) ? 1 : 0;
}
printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
@ -1545,9 +1552,12 @@ int main(int argc, char ** argv) {
my_llama_sampler sampler;
printf("%s: init model\n", __func__);
load_model(&model, fn_chkpt, true);
bool existed = load_model(&model, fn_chkpt_in, true);
bool from_scratch = !existed;
set_param_model(&model);
randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
if (from_scratch) {
randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
}
init_kv_cache(&kv_self, &model, n_batch);
init_sampler(&sampler, lctx);
@ -1559,10 +1569,12 @@ int main(int argc, char ** argv) {
int n_tokens = model.hparams.n_ctx;
int n_vocab = model.hparams.n_vocab;
bool samples_start_after_nl = false;
std::vector<int> train_samples;
train_samples.push_back(0);
for (int i=1; i<train_tokens.size()-n_tokens; ++i) {
if (train_tokens[i-1] == llama_token_nl()) {
if (!samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) {
train_samples.push_back(i);
}
}
@ -1674,7 +1686,7 @@ int main(int argc, char ** argv) {
ggml_free(ctx0);
}
save_model(&model, fn_chkpt);
save_model(&model, fn_chkpt_out);
{
int n_gen = 128;