add train_params and command line option parser
commit ec8e262d1d
parent fcbc4457d6
1 changed file with 367 additions and 87 deletions
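In short (editorial summary, not part of the original commit message): the example's five fixed positional arguments and the hyperparameter locals in main() are replaced by a train_params struct plus get_default_train_params(), train_print_usage() and train_params_parse(), turning every training hyperparameter into a command line option. The standalone program below (an illustration, not part of the commit) condenses the parsing pattern the new train_params_parse() uses: long options get '_' normalized to '-', and every value-taking flag is guarded by a "++i >= argc" bounds check. Flag names and defaults are copied from the diff.

    // Condensed, self-contained sketch of the parsing pattern; not part of the commit.
    #include <algorithm>
    #include <cstdio>
    #include <string>

    int main(int argc, char ** argv) {
        int  n_ctx     = 128;  // default set by get_default_train_params() in the diff
        bool use_flash = true;

        for (int i = 1; i < argc; i++) {
            std::string arg = argv[i];
            if (arg.compare(0, 2, "--") == 0) {
                // accept --some_option as an alias of --some-option
                std::replace(arg.begin(), arg.end(), '_', '-');
            }
            if (arg == "-c" || arg == "--ctx") {
                if (++i >= argc) {  // value-taking flag was the last token
                    fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
                    return 1;
                }
                n_ctx = std::stoi(argv[i]);
            } else if (arg == "--no-flash") {
                use_flash = false;
            } else {
                fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
                return 1;
            }
        }
        printf("n_ctx = %d, use_flash = %s\n", n_ctx, use_flash ? "on" : "off");
        return 0;
    }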
@@ -2150,66 +2150,341 @@ float cosine_decay_restart(int decay_steps, const float alpha, int step, float r
     return cosine_decay(decay_steps, alpha, step);
 }
 
-int main(int argc, char ** argv) {
-    const char * default_model = "ggml-vic7b-uncensored-q4_0.bin";
-    const char * default_train = "shakespeare.txt";
-    const char * default_chkpt_in = "checkpoint.bin";
-    const char * default_chkpt_out = "checkpoint.bin";
-    const char * default_model_out = "ggml-checkpoint-f32.bin";
-    const char * default_argv[6] = {argv[0], default_model, default_train, default_chkpt_in, default_chkpt_out, default_model_out};
+struct train_params {
+    const char * fn_vocab_model;
+    const char * fn_train_data;
+    const char * fn_checkpoint_in;
+    const char * fn_checkpoint_out;
+    const char * fn_model_out;
 
-    if (argc < 6) {
-        fprintf(stderr, "usage: %s model training_data chkpt_in chkpt_out model_out\n", argv[0]);
-        //return 1;
-    }
+    int seed;
+    int n_ctx;
+    int n_embd;
+    int n_mult;
+    int n_head;
+    int n_layer;
+    int n_rotmax;
 
-    int seed = 1;
-    int n_ctx = 256;
-    // int n_ctx = 64;
-    int n_embd = 256;
-    int n_mult = 256;
-    int n_head = 8;
-    int n_layer = 16;
-    int n_rotmax = 64;
+    int n_threads;
+    int n_batch;
+    int n_examples;
+    int n_predict;
 
-    int n_threads = 6;
-    int n_batch = 8;
-    int n_examples = 32;
+    int print_info_interval;
+    int print_details_interval;
 
-    int print_info_interval = 1;
-    int print_details_interval = 2;
 
-    bool samples_start_after_nl = false;
-    bool use_adam = true;
-    bool use_flash = false;
+    bool samples_start_after_nl;
+    bool use_adam;
+    bool use_flash;
 
     // only adam
-    int warmup = 100;
-    int cos_decay_steps = 1000;
-    float cos_decay_restart = 1.1f;
-    float cos_decay_alpha = 0.0f;
+    int warmup;
+    int cos_decay_steps;
+    float cos_decay_restart;
+    float cos_decay_alpha;
 
-    int lbfgs_n_iter = 16;
-    int adam_n_iter = 16;
-    float adam_alpha = 1e-3;
-    float adam_decay = 1e-3;
+    int lbfgs_n_iter;
+    int adam_n_iter;
+    float adam_alpha;
+    float adam_decay;
+};
 
-    if (seed < 0) {
-        srand(time(NULL));
-    } else {
-        srand(seed);
-    }
+struct train_params get_default_train_params() {
+    struct train_params params;
+    params.fn_vocab_model = "ggml-vic7b-uncensored-q4_0.bin";
+    params.fn_train_data = "shakespeare.txt";
+    params.fn_checkpoint_in = "checkpoint.bin";
+    params.fn_checkpoint_out = "checkpoint.bin";
+    params.fn_model_out = "ggml-checkpoint-f32.bin";
+
+    params.seed = -1;
+
+    params.n_ctx = 128;
+    params.n_embd = 256;
+    params.n_mult = 256;
+    params.n_head = 8;
+    params.n_layer = 16;
+    params.n_rotmax = 64;
+
+    params.n_threads = 6;
+    params.n_batch = 8;
+    params.n_examples = 8;
+    params.n_predict = 1024;
+
+    params.print_info_interval = 1;
+    params.print_details_interval = 2;
+
+    params.samples_start_after_nl = false;
+    params.use_adam = true;
+    params.use_flash = true;
+
+    // only adam
+    params.warmup = 100;
+    params.cos_decay_steps = 1000;
+    params.cos_decay_restart = 1.1f;
+    params.cos_decay_alpha = 0.0f;
+
+    params.lbfgs_n_iter = 16;
+    params.adam_n_iter = 16;
+    params.adam_alpha = 1e-3;
+    params.adam_decay = 1e-3;
+
+    return params;
+}
+
+void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help                 show this help message and exit\n");
+    fprintf(stderr, "  --vocab-model FNAME        model path from which to load vocab (default '%s')\n", params->fn_vocab_model);
+    fprintf(stderr, "  --train-data FNAME         path from which to load training data (default '%s')\n", params->fn_train_data);
+    fprintf(stderr, "  --checkpoint-in FNAME      path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in);
+    fprintf(stderr, "  --checkpoint-out FNAME     path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out);
+    fprintf(stderr, "  --model-out FNAME          path to save ggml model (default '%s')\n", params->fn_model_out);
+    fprintf(stderr, "  -s SEED, --seed SEED       RNG seed (default: -1, use random seed for < 0)\n");
+    fprintf(stderr, "  -c N, --ctx N              Context size used during training (default %d)\n", params->n_ctx);
+    fprintf(stderr, "  --embd N                   Embedding size used for new models (default %d)\n", params->n_embd);
+    fprintf(stderr, "  --mult N                   Mult size used for new models, influences feedforward size. (default %d)\n", params->n_mult);
+    fprintf(stderr, "  --head N                   Number of heads for new models (default %d)\n", params->n_head);
+    fprintf(stderr, "  --layer N                  Number of layers for new models (default %d)\n", params->n_layer);
+    fprintf(stderr, "  --rotmax N                 Maximal number Rope dimensions for new models (default %d)\n", params->n_rotmax);
+    fprintf(stderr, "  -t N, --threads N          Number of threads (default %d)\n", params->n_threads);
+    fprintf(stderr, "  -b N, --batch N            Parallel batch size (default %d)\n", params->n_batch);
+    fprintf(stderr, "  -n N, --examples N         Number of examples to train (default %d)\n", params->n_examples);
+    fprintf(stderr, "  --predict N                Number of tokens to generate after training (default %d)\n", params->n_predict);
+    fprintf(stderr, "  --print-info-interval N    Print infos during training each N examples (default %d)\n", params->print_info_interval);
+    fprintf(stderr, "  --print-details-interval N Print details during training each N examples (default %d)\n", params->print_details_interval);
+    fprintf(stderr, "  --samples-after-nl         Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
+    fprintf(stderr, "  --use-lbfgs                Use LBFGS optimizer instead of default Adam\n");
+    fprintf(stderr, "  --use-adam                 Use Adam optimizer (default)\n");
+    fprintf(stderr, "  --no-flash                 Don't use flash attention.\n");
+    fprintf(stderr, "  --use-flash                Use flash attention (default)\n");
+    fprintf(stderr, "  --warmup N                 Number of warmup steps (default %d)\n", params->warmup);
+    fprintf(stderr, "  --cos-decay-steps N        Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
+    fprintf(stderr, "  --cos-decay-restart N      Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
+    fprintf(stderr, "  --cos-decay-alpha N        Cosine decay alpha (default %f)\n", params->cos_decay_alpha);
+    fprintf(stderr, "  --lbfgs-iter N             Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
+    fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
+    fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
+    fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+    fprintf(stderr, "\n");
+}
+
+bool train_params_parse(int argc, char ** argv, struct train_params * params) {
+    bool invalid_param = false;
+    std::string arg;
+    struct train_params default_params = get_default_train_params();
+    const std::string arg_prefix = "--";
+
+    for (int i = 1; i < argc; i++) {
+        arg = argv[i];
+        if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+            std::replace(arg.begin(), arg.end(), '_', '-');
+        }
+
+        if (arg == "--vocab-model") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_vocab_model = argv[i];
+        } else if (arg == "--train-data") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_train_data = argv[i];
+        } else if (arg == "--checkpoint-in") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_checkpoint_in = argv[i];
+        } else if (arg == "--checkpoint-out") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_checkpoint_out = argv[i];
+        } else if (arg == "--model-out") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->fn_model_out = argv[i];
+        } else if (arg == "-s" || arg == "--seed") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->seed = std::stoi(argv[i]);
+        } else if (arg == "-c" || arg == "--ctx") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_ctx = std::stoi(argv[i]);
+        } else if (arg == "--embd") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_embd = std::stoi(argv[i]);
+        } else if (arg == "--mult") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_mult = std::stoi(argv[i]);
+        } else if (arg == "--head") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_head = std::stoi(argv[i]);
+        } else if (arg == "--layer") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_layer = std::stoi(argv[i]);
+        } else if (arg == "--rotmax") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_rotmax = std::stoi(argv[i]);
+        } else if (arg == "-t" || arg == "--threads") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_threads = std::stoi(argv[i]);
+        } else if (arg == "-b" || arg == "--batch") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_batch = std::stoi(argv[i]);
+        } else if (arg == "-n" || arg == "--examples") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_examples = std::stoi(argv[i]);
+        } else if (arg == "--predict") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_predict = std::stoi(argv[i]);
+        } else if (arg == "--print-info-interval") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->print_info_interval = std::stoi(argv[i]);
+        } else if (arg == "--print-details-interval") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->print_details_interval = std::stoi(argv[i]);
+        } else if (arg == "--samples-after-nl") {
+            params->samples_start_after_nl = true;
+        } else if (arg == "--use-lbfgs") {
+            params->use_adam = false;
+        } else if (arg == "--use-adam") {
+            params->use_adam = true;
+        } else if (arg == "--no-flash") {
+            params->use_flash = false;
+        } else if (arg == "--use-flash") {
+            params->use_flash = true;
+        } else if (arg == "--warmup") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->warmup = std::stoi(argv[i]);
+        } else if (arg == "--cos-decay-steps") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->cos_decay_steps = std::stof(argv[i]);
+        } else if (arg == "--cos-decay-restart") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->cos_decay_restart = std::stof(argv[i]);
+        } else if (arg == "--cos-decay-alpha") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->cos_decay_alpha = std::stof(argv[i]);
+        } else if (arg == "--lbfgs-iter") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->lbfgs_n_iter = std::stoi(argv[i]);
+        } else if (arg == "--adam-iter") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_n_iter = std::stoi(argv[i]);
+        } else if (arg == "--adam-alpha") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_alpha = std::stof(argv[i]);
+        } else if (arg == "--adam-decay") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_decay = std::stof(argv[i]);
+        } else if (arg == "-h" || arg == "--help") {
+            train_print_usage(argc, argv, &default_params);
+            exit(0);
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            train_print_usage(argc, argv, &default_params);
+            exit(1);
+        }
+    }
+    if (invalid_param) {
+        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
+        train_print_usage(argc, argv, &default_params);
+        exit(1);
+    }
 
-    const char * fn_model = (argc >= 2) ? argv[1] : default_argv[1];
-    const char * fn_train = (argc >= 3) ? argv[2] : default_argv[2];
-    const char * fn_chkpt_in = (argc >= 4) ? argv[3] : default_argv[3];
-    const char * fn_chkpt_out = (argc >= 5) ? argv[4] : default_argv[4];
-    const char * fn_model_out = (argc >= 6) ? argv[5] : default_argv[5];
+    return true;
+}
 
+int main(int argc, char ** argv) {
+    struct train_params params = get_default_train_params();
+
+    if (!train_params_parse(argc, argv, &params)) {
+        return 1;
+    }
+
+    if (params.seed < 0) {
+        srand(time(NULL));
+    } else {
+        srand(params.seed);
+    }
 
     struct llama_context_params llama_params = llama_context_default_params();
     llama_params.vocab_only = true;
 
-    struct llama_context * lctx = llama_init_from_file(fn_model, llama_params);
+    struct llama_context * lctx = llama_init_from_file(params.fn_vocab_model, llama_params);
 
     struct llama_vocab vocab;
     {
@@ -2232,19 +2507,19 @@ int main(int argc, char ** argv) {
 
     printf("%s: tokenize training data\n", __func__);
     std::vector<llama_token> train_tokens;
-    if (tokenize_file(lctx, fn_train, train_tokens) < 0) {
-        fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, fn_train);
+    if (tokenize_file(lctx, params.fn_train_data, train_tokens) < 0) {
+        fprintf(stderr, "%s: failed to tokenize file '%s'\n", __func__, params.fn_train_data);
     }
     printf("%s: number of training tokens: %d\n", __func__, train_tokens.size());
 
     struct my_llama_model model;
     model.hparams.n_vocab = llama_n_vocab(lctx);
-    model.hparams.n_ctx = n_ctx;
-    model.hparams.n_embd = n_embd;
-    model.hparams.n_mult = n_mult;
-    model.hparams.n_head = n_head;
-    model.hparams.n_layer = n_layer;
-    model.hparams.n_rot = std::min((uint32_t)n_rotmax, model.hparams.n_embd / model.hparams.n_head);
+    model.hparams.n_ctx = params.n_ctx;
+    model.hparams.n_embd = params.n_embd;
+    model.hparams.n_mult = params.n_mult;
+    model.hparams.n_head = params.n_head;
+    model.hparams.n_layer = params.n_layer;
+    model.hparams.n_rot = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head);
 
     print_params(&model.hparams);
 
@@ -2282,6 +2557,7 @@ int main(int argc, char ** argv) {
 
     int n_tokens = model.hparams.n_ctx;
     int n_vocab = model.hparams.n_vocab;
+    int n_batch = params.n_batch;
 
     struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context));
     memset(opt, 0, sizeof(struct ggml_opt_context));
@@ -2290,32 +2566,32 @@ int main(int argc, char ** argv) {
     struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
     opt_params_adam.print_forward_graph = false;
     opt_params_adam.print_backward_graph = false;
-    opt_params_adam.n_threads = n_threads;
-    opt_params_adam.adam.n_iter = adam_n_iter;
-    opt_params_adam.adam.sched = 1.0f;
-    opt_params_adam.adam.alpha = adam_alpha;
-    opt_params_adam.adam.decay = adam_decay;
+    opt_params_adam.n_threads = params.n_threads;
+    opt_params_adam.adam.n_iter = params.adam_n_iter;
+    opt_params_adam.adam.sched = 1.0f;
+    opt_params_adam.adam.alpha = params.adam_alpha;
+    opt_params_adam.adam.decay = params.adam_decay;
 
     opt_params_lbfgs.print_forward_graph = false;
     opt_params_lbfgs.print_backward_graph = false;
-    opt_params_lbfgs.n_threads = n_threads;
-    opt_params_lbfgs.lbfgs.n_iter = lbfgs_n_iter;
+    opt_params_lbfgs.n_threads = params.n_threads;
+    opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
 
     opt->ctx = model.ctx;
-    opt->params = use_adam ? opt_params_adam : opt_params_lbfgs;
+    opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
 
     printf("%s: init model\n", __func__);
-    bool existed = load_checkpoint(&model, opt, fn_chkpt_in, true);
+    bool existed = load_checkpoint(&model, opt, params.fn_checkpoint_in, true);
     set_param_model(&model);
 
-    opt->params = use_adam ? opt_params_adam : opt_params_lbfgs;
+    opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
 
     opt->iter = model.train_its;
     printf("%s: opt iter %d\n", __func__, opt->iter);
 
     bool from_scratch = !existed;
     if (from_scratch) {
-        randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
+        randomize_model(&model, params.seed, 0.0f, 1.0f, -1.0f, +1.0f);
     }
 
     init_kv_cache(&kv_self, &model, 1);
@@ -2328,11 +2604,11 @@ int main(int argc, char ** argv) {
     size_t compute_size = 1024ll*1024ll*1024ll*32ll;
     uint8_t * compute_addr = new uint8_t[compute_size];
 
     GGML_ASSERT(train_tokens.size() > n_tokens);;
     std::vector<int> train_samples;
     train_samples.push_back(0);
     for (int i=1; i<train_tokens.size()-n_tokens; ++i) {
-        if (!samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) {
+        if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) {
             train_samples.push_back(i);
         }
     }
@@ -2343,7 +2619,7 @@ int main(int argc, char ** argv) {
 
     printf("%s: begin training\n", __func__);
 
-    for (int ex=0; ex<n_examples; ++ex) {
+    for (int ex=0; ex<params.n_examples; ++ex) {
         if (ex*n_batch >= train_samples.size()) {
             shuffle_ints(train_samples.data(), train_samples.data() + train_samples.size());
             for (int i=0; i<train_samples.size(); ++i) {
@@ -2351,12 +2627,12 @@ int main(int argc, char ** argv) {
             }
         }
 
-        struct ggml_init_params params = {
+        struct ggml_init_params cparams = {
            /*.mem_size =*/ compute_size,
            /*.mem_buffer =*/ compute_addr,
            /*.no_alloc =*/ false,
        };
-        struct ggml_context * ctx0 = ggml_init(params);
+        struct ggml_context * ctx0 = ggml_init(cparams);
 
         struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
         struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
@@ -2367,13 +2643,13 @@ int main(int argc, char ** argv) {
         int n_past = 0;
 
         ggml_cgraph gf = {};
-        gf.n_threads = n_threads;
+        gf.n_threads = params.n_threads;
 
         get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
 
         struct ggml_tensor * logits =
             (n_past == 0)
-            ? (use_flash
+            ? (params.use_flash
                 ? forward_batch_wo_cache_flash_attn(&model, ctx0, &gf, tokens_input, n_tokens, n_batch)
                 : forward_batch_wo_cache(&model, ctx0, &gf, tokens_input, n_tokens, n_batch))
             : forward_batch(&model, &kv_self, ctx0, &gf, tokens_input, n_tokens, n_past, n_batch);
@@ -2387,9 +2663,14 @@ int main(int argc, char ** argv) {
 
         float error_before_opt = ggml_get_f32_1d(e, 0);
 
-        opt->params.adam.sched = (opt->iter < warmup)
-            ? (float) opt->iter / (float) warmup
-            : cosine_decay_restart(cos_decay_steps, cos_decay_alpha, opt->iter - warmup, cos_decay_restart);
+        opt->params.adam.sched = (opt->iter < params.warmup)
+            ? (float) opt->iter / (float) params.warmup
+            : cosine_decay_restart(
+                params.cos_decay_steps,
+                params.cos_decay_alpha,
+                opt->iter - params.warmup,
+                params.cos_decay_restart);
+
+        printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
 
         // ggml_opt(ctx0, opt->params, e);
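Annotation (editorial, not part of the commit): the hunk above wires the new options into the Adam step-size schedule. sched ramps linearly as opt->iter / warmup until warmup steps have passed, then follows cosine_decay_restart(), where, per the option help text, cos_decay_alpha is the cosine decay alpha and each restart increases the number of cosine decay steps by the factor cos_decay_restart (1.1 by default).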
@@ -2406,8 +2687,7 @@ int main(int argc, char ** argv) {
 
         float error_after_opt = ggml_get_f32_1d(e, 0);
 
-
-        if (ex % print_info_interval == 0) {
+        if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
             printf("Example %d, opt iter %d\n", ex, opt->iter);
             printf("error_before_opt: %.6f\n", error_before_opt);
             printf("error_after_opt: %.6f\n", error_after_opt);
@@ -2415,7 +2695,7 @@ int main(int argc, char ** argv) {
             printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
         }
 
-        if (ex % print_details_interval == 0) {
+        if (params.print_details_interval > 0 && ex % params.print_details_interval == 0) {
             // set_logits_masked(logits, token_notavail, -1e9);
             for (int i=0; i<n_batch; ++i) {
                 init_sampler(&sampler, lctx);
@@ -2443,16 +2723,16 @@ int main(int argc, char ** argv) {
         ggml_free(ctx0);
     }
 
-    if (n_examples > 0) {
-        save_checkpoint(&model, opt, fn_chkpt_out);
+    if (params.n_examples > 0) {
+        save_checkpoint(&model, opt, params.fn_checkpoint_out);
     }
 
-    if (strlen(fn_model_out) > 0) {
-        save_as_llama_model(&vocab, &model, fn_model_out);
+    if (strlen(params.fn_model_out) > 0) {
+        save_as_llama_model(&vocab, &model, params.fn_model_out);
     }
 
     {
-        int n_gen = 1024;
+        int n_gen = params.n_predict;
         int sample_ctx = n_tokens - n_tokens/8;
 
         sampler.params.temp = 0.2;
@@ -2477,15 +2757,15 @@ int main(int argc, char ** argv) {
 
         printf("---\n");
         for (int i=0; i<n_gen; ++i) {
-            struct ggml_init_params params = {
+            struct ggml_init_params cparams = {
                /*.mem_size =*/ compute_size,
                /*.mem_buffer =*/ compute_addr,
                /*.no_alloc =*/ false,
            };
-            struct ggml_context * ctx0 = ggml_init(params);
+            struct ggml_context * ctx0 = ggml_init(cparams);
 
             ggml_cgraph gf = {};
-            gf.n_threads = 1;
+            gf.n_threads = params.n_threads;
 
             int n_past = 0;
             struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
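Usage note (editorial, not part of the commit): after this change the example takes named options instead of the old five positional arguments, so an invocation equivalent to the old "model training_data chkpt_in chkpt_out model_out" form would look roughly as follows; the binary name train-text-from-scratch is an assumption, since the changed file's name is not shown here:

    ./train-text-from-scratch \
        --vocab-model ggml-vic7b-uncensored-q4_0.bin \
        --train-data shakespeare.txt \
        --checkpoint-in checkpoint.bin \
        --checkpoint-out checkpoint.bin \
        --model-out ggml-checkpoint-f32.bin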