add train_params and command-line option parser

commit ec8e262d1d
parent fcbc4457d6
Author: xaedes
Date:   2023-05-30 15:53:55 +02:00


@@ -2150,66 +2150,341 @@ float cosine_decay_restart(int decay_steps, const float alpha, int step, float restart_step_mult) {
     return cosine_decay(decay_steps, alpha, step);
 }
 
-int main(int argc, char ** argv) {
-    const char * default_model = "ggml-vic7b-uncensored-q4_0.bin";
-    const char * default_train = "shakespeare.txt";
-    const char * default_chkpt_in = "checkpoint.bin";
-    const char * default_chkpt_out = "checkpoint.bin";
-    const char * default_model_out = "ggml-checkpoint-f32.bin";
-    const char * default_argv[6] = {argv[0], default_model, default_train, default_chkpt_in, default_chkpt_out, default_model_out};
+struct train_params {
+    const char * fn_vocab_model;
+    const char * fn_train_data;
+    const char * fn_checkpoint_in;
+    const char * fn_checkpoint_out;
+    const char * fn_model_out;
 
-    if (argc < 6) {
-        fprintf(stderr, "usage: %s model training_data chkpt_in chkpt_out model_out\n", argv[0]);
-        //return 1;
-    }
+    int seed;
+    int n_ctx;
+    int n_embd;
+    int n_mult;
+    int n_head;
+    int n_layer;
+    int n_rotmax;
 
-    int seed = 1;
-    int n_ctx = 256;
-    // int n_ctx = 64;
-    int n_embd = 256;
-    int n_mult = 256;
-    int n_head = 8;
-    int n_layer = 16;
-    int n_rotmax = 64;
+    int n_threads;
+    int n_batch;
+    int n_examples;
+    int n_predict;
 
-    int n_threads = 6;
-    int n_batch = 8;
-    int n_examples = 32;
+    int print_info_interval;
+    int print_details_interval;
 
-    int print_info_interval = 1;
-    int print_details_interval = 2;
-
-    bool samples_start_after_nl = false;
-    bool use_adam = true;
-    bool use_flash = false;
+    bool samples_start_after_nl;
+    bool use_adam;
+    bool use_flash;
 
     // only adam
-    int   warmup = 100;
-    int   cos_decay_steps = 1000;
-    float cos_decay_restart = 1.1f;
-    float cos_decay_alpha = 0.0f;
+    int   warmup;
+    int   cos_decay_steps;
+    float cos_decay_restart;
+    float cos_decay_alpha;
 
-    int   lbfgs_n_iter = 16;
-    int   adam_n_iter = 16;
-    float adam_alpha = 1e-3;
-    float adam_decay = 1e-3;
+    int   lbfgs_n_iter;
+    int   adam_n_iter;
+    float adam_alpha;
+    float adam_decay;
+};
 
-    if (seed < 0) {
-        srand(time(NULL));
-    } else {
-        srand(seed);
-    }
+struct train_params get_default_train_params() {
+    struct train_params params;
+    params.fn_vocab_model    = "ggml-vic7b-uncensored-q4_0.bin";
+    params.fn_train_data     = "shakespeare.txt";
+    params.fn_checkpoint_in  = "checkpoint.bin";
+    params.fn_checkpoint_out = "checkpoint.bin";
+    params.fn_model_out      = "ggml-checkpoint-f32.bin";
+    params.seed       = -1;
+    params.n_ctx      = 128;
+    params.n_embd     = 256;
+    params.n_mult     = 256;
+    params.n_head     = 8;
+    params.n_layer    = 16;
+    params.n_rotmax   = 64;
+    params.n_threads  = 6;
+    params.n_batch    = 8;
+    params.n_examples = 8;
+    params.n_predict  = 1024;
+    params.print_info_interval    = 1;
+    params.print_details_interval = 2;
+    params.samples_start_after_nl = false;
+    params.use_adam  = true;
+    params.use_flash = true;
+    // only adam
+    params.warmup            = 100;
+    params.cos_decay_steps   = 1000;
+    params.cos_decay_restart = 1.1f;
+    params.cos_decay_alpha   = 0.0f;
+    params.lbfgs_n_iter = 16;
+    params.adam_n_iter  = 16;
+    params.adam_alpha   = 1e-3;
+    params.adam_decay   = 1e-3;
+    return params;
+}
+
+void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help                 show this help message and exit\n");
+    fprintf(stderr, "  --vocab-model FNAME        model path from which to load vocab (default '%s')\n", params->fn_vocab_model);
+    fprintf(stderr, "  --train-data FNAME         path from which to load training data (default '%s')\n", params->fn_train_data);
+    fprintf(stderr, "  --checkpoint-in FNAME      path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
+    fprintf(stderr, "  --checkpoint-out FNAME     path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
+    fprintf(stderr, "  --model-out FNAME          path to save ggml model (default '%s')\n", params->fn_model_out);
+    fprintf(stderr, "  -s SEED, --seed SEED       RNG seed (default: -1, use random seed for < 0)\n");
+    fprintf(stderr, "  -c N, --ctx N              Context size used during training (default %d)\n", params->n_ctx);
+    fprintf(stderr, "  --embd N                   Embedding size used for new models (default %d)\n", params->n_embd);
+    fprintf(stderr, "  --mult N                   Mult size used for new models, influences feedforward size. (default %d)\n", params->n_mult);
+    fprintf(stderr, "  --head N                   Number of heads for new models (default %d)\n", params->n_head);
+    fprintf(stderr, "  --layer N                  Number of layers for new models (default %d)\n", params->n_layer);
+    fprintf(stderr, "  --rotmax N                 Maximal number Rope dimensions for new models (default %d)\n", params->n_rotmax);
+    fprintf(stderr, "  -t N, --threads N          Number of threads (default %d)\n", params->n_threads);
+    fprintf(stderr, "  -b N, --batch N            Parallel batch size (default %d)\n", params->n_batch);
+    fprintf(stderr, "  -n N, --examples N         Number of examples to train (default %d)\n", params->n_examples);
+    fprintf(stderr, "  --predict N                Number of tokens to generate after training (default %d)\n", params->n_predict);
+    fprintf(stderr, "  --print-info-interval N    Print infos during training each N examples (default %d)\n", params->print_info_interval);
+    fprintf(stderr, "  --print-details-interval N Print details during training each N examples (default %d)\n", params->print_details_interval);
+    fprintf(stderr, "  --samples-after-nl         Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
+    fprintf(stderr, "  --use-lbfgs                Use LBFGS optimizer instead of default Adam\n");
+    fprintf(stderr, "  --use-adam                 Use Adam optimizer (default)\n");
+    fprintf(stderr, "  --no-flash                 Don't use flash attention.\n");
+    fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
+    fprintf(stderr, "  --warmup N                 Number of warmup steps (default %d)\n", params->warmup);
+    fprintf(stderr, "  --cos-decay-steps N        Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
+    fprintf(stderr, "  --cos-decay-restart N      Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
+    fprintf(stderr, "  --cos-decay-alpha N        Cosine decay alpha (default %f)\n", params->cos_decay_alpha);
+    fprintf(stderr, "  --lbfgs-iter N             Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
+    fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
+    fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
+    fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+    fprintf(stderr, "\n");
+}
+
+bool train_params_parse(int argc, char ** argv, struct train_params * params) {
+    bool invalid_param = false;
+    std::string arg;
+    struct train_params default_params = get_default_train_params();
+    const std::string arg_prefix = "--";
+
+    for (int i = 1; i < argc; i++) {
+        arg = argv[i];
+        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+            std::replace(arg.begin(), arg.end(), '_', '-');
+        }
+
+        if (arg == "--vocab-model") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_vocab_model = argv[i];
+        } else if (arg == "--train-data") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_train_data = argv[i];
+        } else if (arg == "--checkpoint-in") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_checkpoint_in = argv[i];
+        } else if (arg == "--checkpoint-out") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_checkpoint_out = argv[i];
+        } else if (arg == "--model-out") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_model_out = argv[i];
+        } else if (arg == "-s" || arg == "--seed") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->seed = std::stoi(argv[i]);
+        } else if (arg == "-c" || arg == "--ctx") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_ctx = std::stoi(argv[i]);
+        } else if (arg == "--embd") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_embd = std::stoi(argv[i]);
+        } else if (arg == "--mult") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_mult = std::stoi(argv[i]);
+        } else if (arg == "--head") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_head = std::stoi(argv[i]);
+        } else if (arg == "--layer") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_layer = std::stoi(argv[i]);
+        } else if (arg == "--rotmax") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rotmax = std::stoi(argv[i]);
+        } else if (arg == "-t" || arg == "--threads") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_threads = std::stoi(argv[i]);
+        } else if (arg == "-b" || arg == "--batch") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_batch = std::stoi(argv[i]);
+        } else if (arg == "-n" || arg == "--examples") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_examples = std::stoi(argv[i]);
+        } else if (arg == "--predict") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_predict = std::stoi(argv[i]);
+        } else if (arg == "--print-info-interval") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->print_info_interval = std::stoi(argv[i]);
+        } else if (arg == "--print-details-interval") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->print_details_interval = std::stoi(argv[i]);
+        } else if (arg == "--samples-after-nl") {
+            params->samples_start_after_nl = true;
+        } else if (arg == "--use-lbfgs") {
+            params->use_adam = false;
+        } else if (arg == "--use-adam") {
+            params->use_adam = true;
+        } else if (arg == "--no-flash") {
+            params->use_flash = false;
+        } else if (arg == "--use-flash") {
+            params->use_flash = true;
+        } else if (arg == "--warmup") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->warmup = std::stoi(argv[i]);
+        } else if (arg == "--cos-decay-steps") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->cos_decay_steps = std::stof(argv[i]);
+        } else if (arg == "--cos-decay-restart") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->cos_decay_restart = std::stof(argv[i]);
+        } else if (arg == "--cos-decay-alpha") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->cos_decay_alpha = std::stof(argv[i]);
+        } else if (arg == "--lbfgs-iter") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->lbfgs_n_iter = std::stoi(argv[i]);
+        } else if (arg == "--adam-iter") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_n_iter = std::stoi(argv[i]);
+        } else if (arg == "--adam-alpha") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_alpha = std::stof(argv[i]);
+        } else if (arg == "--adam-decay") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_decay = std::stof(argv[i]);
+        } else if (arg == "-h" || arg == "--help") {
+            train_print_usage(argc, argv, &default_params);
+            exit(0);
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            train_print_usage(argc, argv, &default_params);
+            exit(1);
+        }
+    }
+
+    if (invalid_param) {
+        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
+        train_print_usage(argc, argv, &default_params);
+        exit(1);
+    }
 
-    const char * fn_model = (argc >= 2) ? argv[1] : default_argv[1];
-    const char * fn_train = (argc >= 3) ? argv[2] : default_argv[2];
-    const char * fn_chkpt_in = (argc >= 4) ? argv[3] : default_argv[3];
-    const char * fn_chkpt_out = (argc >= 5) ? argv[4] : default_argv[4];
-    const char * fn_model_out = (argc >= 6) ? argv[5] : default_argv[5];
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    struct train_params params = get_default_train_params();
+
+    if (!train_params_parse(argc, argv, &params)) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        srand(time(NULL));
+    } else {
+        srand(params.seed);
+    }
 
     struct llama_context_params llama_params = llama_context_default_params();
     llama_params.vocab_only = true;
 
-    struct llama_context * lctx = llama_init_from_file(fn_model, llama_params);
+    struct llama_context * lctx = llama_init_from_file(params.fn_vocab_model, llama_params);
 
     struct llama_vocab vocab;
     {
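
A note on the parser added above: every value-taking flag repeats the same bounds check, `if (++i >= argc) { invalid_param = true; break; }`, before reading `argv[i]`. A hypothetical helper, not part of the commit (the name `consume_value` is mine), could collapse that pattern:

    // Hypothetical refactoring sketch: fetch the value that follows a flag,
    // reporting a missing value instead of reading past the end of argv.
    static bool consume_value(int argc, char ** argv, int & i, const char *& out) {
        if (++i >= argc) {
            return false; // caller sets invalid_param and breaks
        }
        out = argv[i];
        return true;
    }

Each branch would then shrink to one call plus an assignment. Also worth noticing while reading the branches: `--cos-decay-steps` is parsed with `std::stof` even though `cos_decay_steps` is an `int`, so the parsed value is truncated on assignment.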
@@ -2232,19 +2507,19 @@ int main(int argc, char ** argv) {
     printf("%s: tokenize training data\n", __func__);
     std::vector<llama_token> train_tokens;
-    if (tokenize_file(lctx, fn_train, train_tokens) < 0) {
-        fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, fn_train);
+    if (tokenize_file(lctx, params.fn_train_data, train_tokens) < 0) {
+        fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, params.fn_train_data);
     }
     printf("%s: number of training tokens: %d\n", __func__, train_tokens.size());
 
     struct my_llama_model model;
     model.hparams.n_vocab = llama_n_vocab(lctx);
-    model.hparams.n_ctx   = n_ctx;
-    model.hparams.n_embd  = n_embd;
-    model.hparams.n_mult  = n_mult;
-    model.hparams.n_head  = n_head;
-    model.hparams.n_layer = n_layer;
-    model.hparams.n_rot   = std::min((uint32_t)n_rotmax, model.hparams.n_embd / model.hparams.n_head);
+    model.hparams.n_ctx   = params.n_ctx;
+    model.hparams.n_embd  = params.n_embd;
+    model.hparams.n_mult  = params.n_mult;
+    model.hparams.n_head  = params.n_head;
+    model.hparams.n_layer = params.n_layer;
+    model.hparams.n_rot   = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head);
 
     print_params(&model.hparams);
@@ -2282,6 +2557,7 @@ int main(int argc, char ** argv) {
     int n_tokens = model.hparams.n_ctx;
     int n_vocab  = model.hparams.n_vocab;
+    int n_batch  = params.n_batch;
 
     struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
     memset(opt, 0, sizeof(struct ggml_opt_context));
@@ -2290,32 +2566,32 @@ int main(int argc, char ** argv) {
     struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
     opt_params_adam.print_forward_graph = false;
     opt_params_adam.print_backward_graph = false;
-    opt_params_adam.n_threads = n_threads;
-    opt_params_adam.adam.n_iter = adam_n_iter;
-    opt_params_adam.adam.sched = 1.0f;
-    opt_params_adam.adam.alpha = adam_alpha;
-    opt_params_adam.adam.decay = adam_decay;
+    opt_params_adam.n_threads = params.n_threads;
+    opt_params_adam.adam.n_iter = params.adam_n_iter;
+    opt_params_adam.adam.sched = 1.0f;
+    opt_params_adam.adam.alpha = params.adam_alpha;
+    opt_params_adam.adam.decay = params.adam_decay;
 
     opt_params_lbfgs.print_forward_graph = false;
     opt_params_lbfgs.print_backward_graph = false;
-    opt_params_lbfgs.n_threads = n_threads;
-    opt_params_lbfgs.lbfgs.n_iter = lbfgs_n_iter;
+    opt_params_lbfgs.n_threads = params.n_threads;
+    opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
 
     opt->ctx = model.ctx;
-    opt->params = use_adam ? opt_params_adam : opt_params_lbfgs;
+    opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
 
     printf("%s: init model\n", __func__);
-    bool existed = load_checkpoint(&model, opt, fn_chkpt_in, true);
+    bool existed = load_checkpoint(&model, opt, params.fn_checkpoint_in, true);
     set_param_model(&model);
-    opt->params = use_adam ? opt_params_adam : opt_params_lbfgs;
+    opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
 
     opt->iter = model.train_its;
     printf("%s: opt iter %d\n", __func__, opt->iter);
 
     bool from_scratch = !existed;
     if (from_scratch) {
-        randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
+        randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
     }
 
     init_kv_cache(&kv_self, &model, 1);
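
The help text added earlier says a positive `--adam-decay` enables AdamW rather than plain Adam. The difference is decoupled weight decay: the decay term shrinks the weight directly instead of flowing through the gradient moments. A minimal sketch of one parameter update under standard AdamW semantics (ggml's internal update may differ in detail):

    #include <math.h>

    // One AdamW update for a single weight; m_hat and v_hat are the
    // bias-corrected first and second moment estimates of the gradient.
    float adamw_step(float w, float m_hat, float v_hat,
                     float alpha, float decay, float sched) {
        const float eps = 1e-8f;
        // decay acts on the weight itself, outside the moment estimates;
        // with decay == 0.0f this reduces to plain Adam.
        return w - sched * (alpha * m_hat / (sqrtf(v_hat) + eps) + decay * w);
    }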
@@ -2328,11 +2604,11 @@ int main(int argc, char ** argv) {
     size_t compute_size = 1024ll*1024ll*1024ll*32ll;
     uint8_t * compute_addr = new uint8_t[compute_size];
 
     GGML_ASSERT(train_tokens.size() > n_tokens);;
     std::vector<int> train_samples;
     train_samples.push_back(0);
     for (int i=1; i<train_tokens.size()-n_tokens; ++i) {
-        if (!samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) {
+        if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) {
             train_samples.push_back(i);
         }
     }
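
This hunk is where `--samples-after-nl` takes effect: with the flag set, a training sample may only begin directly after a newline token, so examples start at line boundaries instead of arbitrary offsets. The same logic restated as a self-contained helper (the names are mine, not the commit's):

    #include <vector>

    // Collect the token offsets at which a training sample may start.
    std::vector<int> sample_starts(const std::vector<int> & tokens, int n_tokens,
                                   bool after_nl, int nl_token) {
        std::vector<int> starts;
        starts.push_back(0); // offset 0 is always a valid start
        for (int i = 1; i < (int) tokens.size() - n_tokens; ++i) {
            if (!after_nl || tokens[i-1] == nl_token) {
                starts.push_back(i);
            }
        }
        return starts;
    }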
@@ -2343,7 +2619,7 @@ int main(int argc, char ** argv) {
     printf("%s: begin training\n", __func__);
 
-    for (int ex=0; ex<n_examples; ++ex) {
+    for (int ex=0; ex<params.n_examples; ++ex) {
         if (ex*n_batch >= train_samples.size()) {
             shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size());
             for (int i=0; i<train_samples.size(); ++i) {
@@ -2351,12 +2627,12 @@ int main(int argc, char ** argv) {
             }
         }
 
-        struct ggml_init_params params = {
+        struct ggml_init_params cparams = {
             /*.mem_size   =*/ compute_size,
             /*.mem_buffer =*/ compute_addr,
             /*.no_alloc   =*/ false,
         };
-        struct ggml_context * ctx0 = ggml_init(params);
+        struct ggml_context * ctx0 = ggml_init(cparams);
 
         struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
         struct ggml_tensor * after_opt_probs        = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
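
The rename from `params` to `cparams` in this hunk is not cosmetic: `main` now owns a `struct train_params params`, and keeping the old name for the local `ggml_init_params` would shadow it for the rest of the block. A self-contained illustration with stub types (not the real definitions):

    #include <cstdio>

    struct train_params     { int n_threads; };
    struct ggml_init_params { unsigned long mem_size; };

    int main() {
        train_params params = { 6 };
        {
            // Naming this local 'params' would hide the outer train_params
            // object; 'cparams' keeps both reachable inside the block.
            ggml_init_params cparams = { 1024 };
            std::printf("%lu bytes, %d threads\n", cparams.mem_size, params.n_threads);
        }
        return 0;
    }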
@@ -2367,13 +2643,13 @@ int main(int argc, char ** argv) {
         int n_past = 0;
 
         ggml_cgraph gf = {};
-        gf.n_threads = n_threads;
+        gf.n_threads = params.n_threads;
 
         get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
 
         struct ggml_tensor * logits =
             (n_past == 0)
-            ? (use_flash
+            ? (params.use_flash
                 ? forward_batch_wo_cache_flash_attn(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
                 : forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch))
             : forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
@@ -2387,9 +2663,14 @@ int main(int argc, char ** argv) {
         float error_before_opt = ggml_get_f32_1d(e, 0);
 
-        opt->params.adam.sched = (opt->iter < warmup)
-            ? (float) opt->iter / (float) warmup
-            : cosine_decay_restart(cos_decay_steps, cos_decay_alpha, opt->iter - warmup, cos_decay_restart);
+        opt->params.adam.sched = (opt->iter < params.warmup)
+            ? (float) opt->iter / (float) params.warmup
+            : cosine_decay_restart(
+                params.cos_decay_steps,
+                params.cos_decay_alpha,
+                opt->iter - params.warmup,
+                params.cos_decay_restart);
+
         printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
 
         // ggml_opt(ctx0, opt->params, e);
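
Taken together, the expression above ramps `adam.sched` linearly from 0 to 1 over `warmup` iterations, then follows cosine decay with warm restarts, where each restart stretches the decay window by the `--cos-decay-restart` factor. A minimal self-contained sketch, assuming the `cosine_decay`/`cosine_decay_restart` semantics defined earlier in this file:

    #include <math.h>

    // sched in [0,1]: linear warmup, then cosine decay with warm restarts.
    float lr_sched(int iter, int warmup, int decay_steps, float alpha, float restart_mult) {
        if (iter < warmup) {
            return (float) iter / (float) warmup;
        }
        int step = iter - warmup;
        while (step > decay_steps) {
            step       -= decay_steps;                        // cycle finished,
            decay_steps = (int) (decay_steps * restart_mult); // stretch the next one
        }
        const float cd = 0.5f*(1.0f + cosf(3.14159265f * step / (float) decay_steps));
        return (1.0f - alpha)*cd + alpha; // decays from 1.0 down to the alpha floor
    }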
@@ -2406,8 +2687,7 @@ int main(int argc, char ** argv) {
         float error_after_opt = ggml_get_f32_1d(e, 0);
 
-        if (ex % print_info_interval == 0) {
+        if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
             printf("Example %d, opt iter %d\n", ex, opt->iter);
             printf("error_before_opt: %.6f\n", error_before_opt);
             printf("error_after_opt: %.6f\n", error_after_opt);
@@ -2415,7 +2695,7 @@ int main(int argc, char ** argv) {
             printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
         }
 
-        if (ex % print_details_interval == 0) {
+        if (params.print_details_interval > 0 && ex % params.print_details_interval == 0) {
             // set_logits_masked(logits, token_notavail, -1e9);
             for (int i=0; i<n_batch; ++i) {
                 init_sampler(&sampler, lctx);
@@ -2443,16 +2723,16 @@ int main(int argc, char ** argv) {
         ggml_free(ctx0);
     }
 
-    if (n_examples > 0) {
-        save_checkpoint(&model, opt, fn_chkpt_out);
+    if (params.n_examples > 0) {
+        save_checkpoint(&model, opt, params.fn_checkpoint_out);
     }
 
-    if (strlen(fn_model_out) > 0) {
-        save_as_llama_model(&vocab, &model, fn_model_out);
+    if (strlen(params.fn_model_out) > 0) {
+        save_as_llama_model(&vocab, &model, params.fn_model_out);
     }
 
     {
-        int n_gen = 1024;
+        int n_gen = params.n_predict;
         int sample_ctx = n_tokens - n_tokens/8;
 
         sampler.params.temp = 0.2;
@@ -2477,15 +2757,15 @@ int main(int argc, char ** argv) {
         printf("---\n");
         for (int i=0; i<n_gen; ++i) {
-            struct ggml_init_params params = {
+            struct ggml_init_params cparams = {
                 /*.mem_size   =*/ compute_size,
                 /*.mem_buffer =*/ compute_addr,
                 /*.no_alloc   =*/ false,
            };
-            struct ggml_context * ctx0 = ggml_init(params);
+            struct ggml_context * ctx0 = ggml_init(cparams);
 
             ggml_cgraph gf = {};
-            gf.n_threads = 1;
+            gf.n_threads = params.n_threads;
 
             int n_past = 0;
             struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);