use optimization callback in training

allows a dynamic learning schedule and different batch data for each iteration, without relying on a low n_iter and a high n_examples setting

reduces runtime by avoiding restarts of the optimization function, and improves training convergence by providing a different batch for each iteration
Author: xaedes
Date:   2023-07-02 22:18:50 +02:00
Parent: bfc3119139
Commit: d7aa4d9576

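The learning-rate logic the new callback implements reduces to a linear warmup followed by a cosine decay, floored so the Adam step never falls below adam_min_alpha. A minimal standalone sketch of that computation (cosine_decay below assumes the standard annealing form; the committed code calls cosine_decay_restart, which can additionally restart the decay period when enable_restart is set):

#include <math.h>

// Sketch of the per-iteration schedule multiplier computed in opt_callback.
static float cosine_decay(int decay_steps, float alpha, int step) {
    if (step > decay_steps) step = decay_steps;
    float decay = 0.5f * (1.0f + cosf(3.14159265f * step / decay_steps));
    return (1.0f - alpha) * decay + alpha; // falls from 1.0 toward alpha
}

static float compute_sched(int iter, int warmup, int decay_steps, float decay_alpha,
                           float adam_alpha, float adam_min_alpha) {
    float sched = (iter < warmup)
        ? (float) iter / (float) warmup                          // linear warmup: 0 -> 1
        : cosine_decay(decay_steps, decay_alpha, iter - warmup); // then cosine decay
    float min_sched = adam_min_alpha / adam_alpha;               // schedule floor
    return min_sched + sched * (1.0f - min_sched);               // clamp into [min_sched, 1]
}

With the defaults in this commit (adam_alpha = 1e-3f, adam_min_alpha = 1e-4f), min_sched = 0.1, so the effective Adam learning rate always stays within [1e-4, 1e-3].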

@@ -3418,7 +3418,7 @@ struct train_params get_default_train_params() {
     params.n_threads = 6;
     params.n_batch = 8;
-    params.n_examples = 8;
+    params.n_examples = 1;
     params.n_predict = 1024;
     params.print_info_interval = 1;
@@ -3441,8 +3441,8 @@ struct train_params get_default_train_params() {
     params.cos_decay_alpha = 0.0f;
     params.enable_restart = false;
-    params.lbfgs_n_iter = 16;
-    params.adam_n_iter = 16;
+    params.lbfgs_n_iter = 256;
+    params.adam_n_iter = 256;
     params.adam_alpha = 1e-3f;
     params.adam_min_alpha = 1e-4f;
     params.adam_decay = 1e-1f;
@@ -3803,6 +3803,61 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
     return true;
 }
 
+struct opt_callback_data {
+    struct train_params * params;
+    struct ggml_opt_context * opt;
+    llama_token * tokens_data;
+    size_t tokens_size;
+    int * samples_data;
+    size_t samples_size;
+    int shuffle_countdown;
+    struct ggml_tensor * tokens_input;
+    struct ggml_tensor * target_logits;
+    struct ggml_tensor * target_probs;
+};
+
+void opt_callback(void * vdata, float * sched) {
+    struct opt_callback_data * data = (struct opt_callback_data *) vdata;
+    struct train_params * params    = data->params;
+    struct ggml_opt_context * opt   = data->opt;
+    int n_batch = params->n_batch;
+
+    *sched = (opt->iter < params->warmup)
+                ? (float) opt->iter / (float) params->warmup
+                : cosine_decay_restart(
+                    params->cos_decay_steps,
+                    params->cos_decay_alpha,
+                    opt->iter - params->warmup,
+                    params->cos_decay_restart,
+                    params->enable_restart);
+    float min_sched = params->adam_min_alpha / params->adam_alpha;
+    *sched = min_sched + *sched * (1.0f - min_sched);
+
+    int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
+    printf("%s: iter=%*d, sched=%f loss0=%f loss=%f | improvement: %*d>\n", __func__, 6, opt->iter, *sched, opt->loss_before, opt->loss_after, impr_plot, (int)0);
+
+    if (data->shuffle_countdown < n_batch) {
+        printf("%s: reshuffle samples\n", __func__);
+        shuffle_ints(data->samples_data, data->samples_data + data->samples_size);
+        for (int i = 0; i < (int) data->samples_size; ++i) {
+            GGML_ASSERT(data->samples_data[i]+params->n_ctx-1 < (int) data->tokens_size);
+        }
+        data->shuffle_countdown = data->samples_size;
+    }
+
+    get_example_targets_batch(
+        data->samples_data,
+        data->samples_size,
+        data->tokens_data,
+        data->tokens_size,
+        opt->iter,
+        data->tokens_input,
+        data->target_logits,
+        data->target_probs);
+
+    data->shuffle_countdown -= n_batch;
+}
+
 int main(int argc, char ** argv) {
     struct train_params params = get_default_train_params();
@@ -3975,6 +4030,18 @@ int main(int argc, char ** argv) {
     printf("%s: begin training\n", __func__);
 
+    struct opt_callback_data opt_cb_data;
+    opt_cb_data.params = &params;
+    opt_cb_data.opt = opt;
+    opt_cb_data.tokens_data = train_tokens.data();
+    opt_cb_data.tokens_size = train_tokens.size();
+    opt_cb_data.samples_data = train_samples.data();
+    opt_cb_data.samples_size = train_samples.size();
+    opt_cb_data.shuffle_countdown = train_samples.size();
+    opt_cb_data.tokens_input = NULL;
+    opt_cb_data.target_logits = NULL;
+    opt_cb_data.target_probs = NULL;
+
     int64_t t0 = ggml_time_ms();
 
     for (int ex = 0; ex < params.n_examples; ++ex) {
@@ -3998,6 +4065,10 @@ int main(int argc, char ** argv) {
         struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
         struct ggml_tensor * target_probs  = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
 
+        opt_cb_data.tokens_input  = tokens_input;
+        opt_cb_data.target_logits = target_logits;
+        opt_cb_data.target_probs  = target_probs;
+
         int n_past = 0;
 
         struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
@@ -4009,8 +4080,6 @@ int main(int argc, char ** argv) {
         struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
         struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
 
-        get_example_targets_batch(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs);
-
         GGML_ASSERT(n_past == 0);
 
         struct ggml_tensor * loss = NULL;
@@ -4062,7 +4131,7 @@ int main(int argc, char ** argv) {
         printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
 
-        ggml_opt_resume_g(ctx0, opt, loss, gf, gb, NULL, NULL);
+        ggml_opt_resume_g(ctx0, opt, loss, gf, gb, &opt_callback, (void *) &opt_cb_data);
 
         size_t used_mem_after_opt = ggml_used_mem(ctx0);
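Because the callback receives the live opt->iter through ggml_opt_context, every optimizer iteration can fetch a fresh batch without re-entering ggml_opt_resume_g, which is what lets n_examples drop to 1 while adam_n_iter rises to 256. The indexing this enables looks roughly like the following (pick_sample is a hypothetical helper for illustration; the real lookup happens inside get_example_targets_batch):

#include <stddef.h>

// Illustrative only: walk a shuffled sample list batch by batch, wrapping
// around at the end, so iteration `iter` sees samples disjoint from the
// previous iteration until the list is exhausted and reshuffled.
static int pick_sample(const int * samples, size_t n_samples,
                       int iter, int n_batch, int k) {
    return samples[((size_t) iter * (size_t) n_batch + (size_t) k) % n_samples];
}

The shuffle_countdown field complements this: once fewer than n_batch unused samples remain, the callback reshuffles the sample order and re-checks (via GGML_ASSERT) that each sample offset plus n_ctx still fits inside the tokenized training data.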