add gradient accumulation

Specify the number of gradient accumulation steps with '--grad-acc N'. This simulates a larger effective batch size of grad_acc*batch.
parent d3afd7131e
commit c1c3b0e0c2

4 changed files with 264 additions and 198 deletions
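For readers new to the technique: gradient accumulation runs several forward/backward passes over micro-batches and sums their scaled gradients before applying a single optimizer update, so the update behaves like one computed on a batch of grad_acc*batch while activation memory stays at the per-batch level. A self-contained toy sketch of the idea in plain C++ (illustrative only, not code from this commit; the data, learning rate, and loss are made up):

```cpp
// Toy illustration of gradient accumulation: accumulating the averaged
// gradients of N micro-batches before one update behaves like using a
// batch that is N times larger.
#include <cstdio>
#include <vector>

int main() {
    // toy data: fit w to minimize mean (w*x - y)^2 over 8 samples
    std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<float> y = {2, 4, 6, 8, 10, 12, 14, 16};
    float w = 0.0f;
    const float lr = 0.01f;

    const int n_grad_acc = 4;                  // like --grad-acc N
    const int n_batch    = 2;                  // micro-batch size
    const float scale    = 1.0f / n_grad_acc;  // like accum_norm in this commit

    float g_acc = 0.0f;                        // accumulated gradient
    for (int step = 0; step < n_grad_acc; ++step) {
        float g = 0.0f;
        for (int i = 0; i < n_batch; ++i) {
            const int k = step*n_batch + i;
            g += 2.0f * (w*x[k] - y[k]) * x[k] / n_batch; // micro-batch gradient
        }
        g_acc += g * scale;                    // average over accumulation steps
    }
    w -= lr * g_acc;                           // one update covering 4*2 = 8 samples
    printf("w after one accumulated step: %f\n", w);
    return 0;
}
```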
@@ -1730,6 +1730,7 @@ struct train_params {
     int n_ctx;
     int n_threads;
     int n_batch;
+    int n_gradient_accumulation;

     bool custom_n_ctx;

@@ -1804,6 +1805,7 @@ struct train_params get_default_train_params() {
     params.n_ctx = 128;
     params.n_threads = 6;
     params.n_batch = 8;
+    params.n_gradient_accumulation = 1;

     params.custom_n_ctx = false;

@@ -1880,6 +1882,7 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
     fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx);
     fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
     fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
+    fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
     fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
     fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
     fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
@@ -2015,6 +2018,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->n_batch = std::stoi(argv[i]);
+        } else if (arg == "--grad-acc") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_gradient_accumulation = std::stoi(argv[i]);
         } else if (arg == "--norm-rms-eps") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -2299,83 +2308,85 @@ void print_duration(float fmillis) {
     printf("%02lld:%02lld:%02lld", hours, minutes, seconds);
 }

-void opt_callback(void * vdata, float * sched) {
+void opt_callback(void * vdata, int accum_step, float * sched) {
     struct opt_callback_data * data = (struct opt_callback_data *) vdata;
     struct train_params * params = data->params;
     struct ggml_opt_context * opt = data->opt;
     int n_batch = params->n_batch;
     int n_ctx = params->n_ctx;

-    int64_t now = ggml_time_ms();
-    if (now > data->last_time) {
-        float dt = now - data->last_time;
-        if (data->time_per_iter == 0) {
-            data->time_per_iter = dt;
-        } else {
-            const float gain = 0.7f;
-            data->time_per_iter = data->time_per_iter*(1.0f-gain) + dt*gain;
-        }
-    }
-    data->last_time = now;
-    float remaining_time = 0;
-    if (data->time_per_iter > 0) {
-        const int n_iter = params->use_adam ? params->adam_n_iter : params->lbfgs_n_iter;
-        const int done_iter = opt->iter - data->first_iter;
-        const int remaining_iter = n_iter - done_iter;
-        remaining_time = remaining_iter * data->time_per_iter;
-    }
-
-    const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
-    if (save_now) {
-        int new_iters = opt->iter - data->last_save_iter;
-        data->lora->train_its += new_iters;
-        data->lora->train_samples += new_iters * n_batch;
-        data->lora->train_tokens += new_iters * n_batch * n_ctx;
-
-        if (strlen(params->fn_checkpoint_out) > 0) {
-            save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
-            save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, -1, params->fn_latest);
-        }
-        if (strlen(params->fn_lora_out) > 0) {
-            save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest);
-            save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, -1, params->fn_latest);
-        }
-        data->last_save_iter = opt->iter;
-    }
-
-    *sched = (opt->iter < params->warmup)
-        ? (float) opt->iter / (float) params->warmup
-        : cosine_decay_restart(
-            params->cos_decay_steps,
-            params->cos_decay_min,
-            opt->iter - params->warmup,
-            params->cos_decay_restart,
-            params->enable_restart);
-    float min_sched = params->adam_min_alpha / params->adam_alpha;
-    *sched = min_sched + *sched * (1.0f - min_sched);
-
-    int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
-    if (impr_plot > 0) impr_plot = 0;
-    if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0;
-    printf("%s: iter=%*d sched=%f loss=%f",
-        __func__, 6, opt->iter, *sched, opt->loss_after);
-    if (data->time_per_iter > 0) {
-        printf(" dt=");
-        print_duration(data->time_per_iter);
-        printf(" eta=");
-        print_duration(remaining_time);
-    }
-
-    float improvement = opt->loss_before - opt->loss_after;
-    const float plot_scale = 10.0f;
-    int bar_len = (int)(1 + improvement*plot_scale + 0.5);
-    printf(" |");
-    for (int i=0; i<bar_len; ++i) {
-        printf("-");
-    }
-    printf(">");
-    // printf("improvement: %*d>", impr_plot, (int)0);
-    printf("\n");
+    if (accum_step == 0) {
+        int64_t now = ggml_time_ms();
+        if (now > data->last_time) {
+            float dt = now - data->last_time;
+            if (data->time_per_iter == 0) {
+                data->time_per_iter = dt;
+            } else {
+                const float gain = 0.7f;
+                data->time_per_iter = data->time_per_iter*(1.0f-gain) + dt*gain;
+            }
+        }
+        data->last_time = now;
+        float remaining_time = 0;
+        if (data->time_per_iter > 0) {
+            const int n_iter = params->use_adam ? params->adam_n_iter : params->lbfgs_n_iter;
+            const int done_iter = opt->iter - data->first_iter;
+            const int remaining_iter = n_iter - done_iter;
+            remaining_time = remaining_iter * data->time_per_iter;
+        }
+
+        const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
+        if (save_now) {
+            int new_iters = opt->iter - data->last_save_iter;
+            data->lora->train_its += new_iters;
+            data->lora->train_samples += new_iters * n_batch;
+            data->lora->train_tokens += new_iters * n_batch * n_ctx;
+
+            if (strlen(params->fn_checkpoint_out) > 0) {
+                save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
+                save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, -1, params->fn_latest);
+            }
+            if (strlen(params->fn_lora_out) > 0) {
+                save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest);
+                save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, -1, params->fn_latest);
+            }
+            data->last_save_iter = opt->iter;
+        }
+
+        *sched = (opt->iter < params->warmup)
+            ? (float) opt->iter / (float) params->warmup
+            : cosine_decay_restart(
+                params->cos_decay_steps,
+                params->cos_decay_min,
+                opt->iter - params->warmup,
+                params->cos_decay_restart,
+                params->enable_restart);
+        float min_sched = params->adam_min_alpha / params->adam_alpha;
+        *sched = min_sched + *sched * (1.0f - min_sched);
+
+        int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
+        if (impr_plot > 0) impr_plot = 0;
+        if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0;
+        printf("%s: iter=%*d sched=%f loss=%f",
+            __func__, 6, opt->iter, *sched, opt->loss_after);
+        if (data->time_per_iter > 0) {
+            printf(" dt=");
+            print_duration(data->time_per_iter);
+            printf(" eta=");
+            print_duration(remaining_time);
+        }
+
+        float improvement = opt->loss_before - opt->loss_after;
+        const float plot_scale = 10.0f;
+        int bar_len = (int)(1 + improvement*plot_scale + 0.5);
+        printf(" |");
+        for (int i=0; i<bar_len; ++i) {
+            printf("-");
+        }
+        printf(">");
+        // printf("improvement: %*d>", impr_plot, (int)0);
+        printf("\n");
+    }

     if (data->shuffle_countdown < n_batch) {
         printf("%s: reshuffle samples\n", __func__);
@@ -2491,30 +2502,32 @@ int main(int argc, char ** argv) {
     // set opt params from command line
     if (params.use_adam) {
         opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
         opt->params.print_forward_graph = false;
         opt->params.print_backward_graph = false;
         opt->params.n_threads = params.n_threads;
         opt->params.past = params.opt_past;
         opt->params.delta = params.opt_delta;
         opt->params.max_no_improvement = params.opt_max_no_improvement;
+        opt->params.n_gradient_accumulation = params.n_gradient_accumulation;
         opt->params.adam.n_iter = params.adam_n_iter;
         opt->params.adam.sched = 1.0f;
         opt->params.adam.alpha = params.adam_alpha;
         opt->params.adam.decay = params.adam_decay;
         opt->params.adam.decay_min_ndim = params.adam_decay_min_ndim;
         opt->params.adam.beta1 = params.adam_beta1;
         opt->params.adam.beta2 = params.adam_beta2;
         opt->params.adam.gclip = params.adam_gclip;
         opt->params.adam.eps_f = params.adam_eps_f;
     } else {
         opt->params = ggml_opt_default_params(GGML_OPT_LBFGS);
         opt->params.print_forward_graph = false;
         opt->params.print_backward_graph = false;
         opt->params.n_threads = params.n_threads;
         opt->params.past = params.opt_past;
         opt->params.delta = params.opt_delta;
         opt->params.max_no_improvement = params.opt_max_no_improvement;
+        opt->params.n_gradient_accumulation = params.n_gradient_accumulation;
         opt->params.lbfgs.n_iter = params.lbfgs_n_iter;
     }

     ggml_allocr * alloc = NULL;
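As a rough worked example with made-up numbers (not taken from this commit): running with `-b 8 --grad-acc 4` makes each optimizer iteration accumulate 4 micro-batches of 8 examples, so the Adam or L-BFGS update is computed from 8 * 4 = 32 examples, while the activation memory of a single graph evaluation still corresponds to a batch of 8.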
@@ -1299,8 +1299,9 @@ struct train_params {
     int n_ff;

     int n_threads;
-    int n_batch;
     int n_examples;
+    int n_batch;
+    int n_gradient_accumulation;

     float f_norm_rms_eps;
     float rope_freq_base;
@@ -1362,8 +1363,9 @@ struct train_params get_default_train_params() {
     params.n_ff = 768;

     params.n_threads = 6;
-    params.n_batch = 8;
     params.n_examples = 1;
+    params.n_batch = 8;
+    params.n_gradient_accumulation = 1;

     params.f_norm_rms_eps = 1e-5f;
     params.rope_freq_base = 10000.0f;
@@ -1428,8 +1430,9 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
     fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
     fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
     fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
-    fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
     fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples);
+    fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
+    fprintf(stderr, " --grad-acc N Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
     fprintf(stderr, " --print-info-interval N Print infos during training each N examples (default %d)\n", params->print_info_interval);
     fprintf(stderr, " --samples-after-nl Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
     fprintf(stderr, " --use-lbfgs Use LBFGS optimizer instead of default Adam\n");
@@ -1591,6 +1594,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->n_batch = std::stoi(argv[i]);
+        } else if (arg == "--grad-acc") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_gradient_accumulation = std::stoi(argv[i]);
         } else if (arg == "-n" || arg == "--examples") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1779,46 +1788,49 @@ struct opt_callback_data {
     struct ggml_tensor * target_probs;
 };

-void opt_callback(void * vdata, float * sched) {
+void opt_callback(void * vdata, int accum_step, float * sched) {
     struct opt_callback_data * data = (struct opt_callback_data *) vdata;
     struct train_params * params = data->params;
     struct ggml_opt_context * opt = data->opt;
     int n_batch = params->n_batch;
     int n_ctx = params->n_ctx;

-    const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
-    if (save_now) {
-        int new_iters = opt->iter - data->last_save_iter;
-        data->model->train_its += new_iters;
-        data->model->train_samples += new_iters * n_batch;
-        data->model->train_tokens += new_iters * n_batch * n_ctx;
-
-        if (strlen(params->fn_checkpoint_out) > 0) {
-            save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
-            save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, -1, params->fn_latest);
-        }
-        if (strlen(params->fn_model_out) > 0) {
-            save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, opt->iter, params->fn_latest);
-            save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, -1, params->fn_latest);
-        }
-        data->last_save_iter = opt->iter;
-    }
-
-    *sched = (opt->iter < params->warmup)
-        ? (float) opt->iter / (float) params->warmup
-        : cosine_decay_restart(
-            params->cos_decay_steps,
-            params->cos_decay_min,
-            opt->iter - params->warmup,
-            params->cos_decay_restart,
-            params->enable_restart);
-    float min_sched = params->adam_min_alpha / params->adam_alpha;
-    *sched = min_sched + *sched * (1.0f - min_sched);
-
-    int impr_plot = std::isnan(opt->loss_after) ? 0 : -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
-    printf("%s: iter=%*d, sched=%f loss0=%f loss=%f | improvement: %*d>\n", __func__, 6, opt->iter, *sched, opt->loss_before, opt->loss_after, impr_plot, (int)0);
+    if (accum_step == 0) {
+        const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
+        if (save_now) {
+            int new_iters = opt->iter - data->last_save_iter;
+            data->model->train_its += new_iters;
+            data->model->train_samples += new_iters * n_batch;
+            data->model->train_tokens += new_iters * n_batch * n_ctx;
+
+            if (strlen(params->fn_checkpoint_out) > 0) {
+                save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
+                save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, -1, params->fn_latest);
+            }
+            if (strlen(params->fn_model_out) > 0) {
+                save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, opt->iter, params->fn_latest);
+                save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, -1, params->fn_latest);
+            }
+            data->last_save_iter = opt->iter;
+        }
+
+        *sched = (opt->iter < params->warmup)
+            ? (float) opt->iter / (float) params->warmup
+            : cosine_decay_restart(
+                params->cos_decay_steps,
+                params->cos_decay_min,
+                opt->iter - params->warmup,
+                params->cos_decay_restart,
+                params->enable_restart);
+        float min_sched = params->adam_min_alpha / params->adam_alpha;
+        *sched = min_sched + *sched * (1.0f - min_sched);
+
+        int impr_plot = std::isnan(opt->loss_after) ? 0 : -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
+        printf("%s: iter=%*d, sched=%f loss0=%f loss=%f | improvement: %*d>\n", __func__, 6, opt->iter, *sched, opt->loss_before, opt->loss_after, impr_plot, (int)0);
+    }

     if (data->shuffle_countdown < n_batch) {
         printf("%s: reshuffle samples\n", __func__);
         shuffle_ints(data->samples_data, data->samples_data + data->samples_size);
@@ -1917,29 +1929,31 @@ int main(int argc, char ** argv) {

     struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
     struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
     opt_params_adam.print_forward_graph = false;
     opt_params_adam.print_backward_graph = false;
     opt_params_adam.n_threads = params.n_threads;
     opt_params_adam.past = params.opt_past;
     opt_params_adam.delta = params.opt_delta;
     opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
+    opt_params_adam.n_gradient_accumulation = params.n_gradient_accumulation;
     opt_params_adam.adam.n_iter = params.adam_n_iter;
     opt_params_adam.adam.sched = 1.0f;
     opt_params_adam.adam.alpha = params.adam_alpha;
     opt_params_adam.adam.decay = params.adam_decay;
     opt_params_adam.adam.decay_min_ndim = params.adam_decay_min_ndim;
     opt_params_adam.adam.beta1 = params.adam_beta1;
     opt_params_adam.adam.beta2 = params.adam_beta2;
     opt_params_adam.adam.gclip = params.adam_gclip;
     opt_params_adam.adam.eps_f = params.adam_eps_f;

     opt_params_lbfgs.print_forward_graph = false;
     opt_params_lbfgs.print_backward_graph = false;
     opt_params_lbfgs.n_threads = params.n_threads;
-    opt_params_adam.past = params.opt_past;
-    opt_params_adam.delta = params.opt_delta;
-    opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
+    opt_params_lbfgs.past = params.opt_past;
+    opt_params_lbfgs.delta = params.opt_delta;
+    opt_params_lbfgs.max_no_improvement = params.opt_max_no_improvement;
+    opt_params_lbfgs.n_gradient_accumulation = params.n_gradient_accumulation;
     opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;

     opt->ctx = model.ctx;
     opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
ggml.c (150 lines changed)
@@ -19112,7 +19112,7 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float * x) {
 }

 static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
-    int i = 0;
+    int64_t i = 0;
     for (int p = 0; p < np; ++p) {
         const int64_t ne = ggml_nelements(ps[p]) ;
         // TODO: add function to get all elements at once
@@ -19122,6 +19122,17 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
     }
 }

+static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) {
+    int64_t i = 0;
+    for (int p = 0; p < np; ++p) {
+        const int64_t ne = ggml_nelements(ps[p]) ;
+        // TODO: add function to get all elements at once
+        for (int64_t j = 0; j < ne; ++j) {
+            g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale;
+        }
+    }
+}
+
 //
 // ADAM
 //
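A note on the `scale` argument (my reading of the diff, not text from the commit): each accumulation step calls `ggml_opt_acc_grad` with `scale = accum_norm = 1/n_accum`, so after all steps the buffer holds the average of the per-step gradients, and the loss is averaged the same way via `fx *= accum_norm`:

    g[i] = (1/N) * sum_{k=1..N} dL_k/dtheta_i,    fx = (1/N) * sum_{k=1..N} L_k,    N = n_gradient_accumulation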
@@ -19170,26 +19181,37 @@ static enum ggml_opt_result ggml_opt_adam(
     const float eps = params.adam.eps;
     const float gclip = params.adam.gclip;
     const int decay_min_ndim = params.adam.decay_min_ndim;
+    const int n_accum = MAX(1, params.n_gradient_accumulation);
+    const float accum_norm = 1.0f / (float) n_accum;

+    float * g = opt->adam.g->data;  // gradients
     float * m = opt->adam.m->data;  // first moment
     float * v = opt->adam.v->data;  // second moment

     float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values

-    if (callback) {
-        callback(callback_data, &sched);
-    }
-
-    // compute the function value
-    // ggml_graph_reset (gf);
-    ggml_set_f32 (f->grad, 1.0f);
-
     struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
     struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
     cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
-    ggml_graph_compute(gb, &cplan);

-    opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
+    // compute the function value
+
+    float fx = 0;
+    ggml_set_zero(opt->adam.g);
+    for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+        if (callback) {
+            callback(callback_data, accum_step, &sched);
+        }
+        // ggml_graph_reset (gf);
+        ggml_set_f32 (f->grad, 1.0f);
+        ggml_graph_compute(gb, &cplan);
+        ggml_opt_acc_grad(np, ps, g, accum_norm);
+        fx += ggml_get_f32_1d(f, 0);
+    }
+    fx *= accum_norm;
+
+    opt->adam.fx_prev = fx;
     opt->adam.fx_best = opt->adam.fx_prev;
     if (pf) {
         pf[opt->iter % params.past] = opt->adam.fx_prev;
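The zero/accumulate/average block above reappears almost verbatim in the per-iteration Adam evaluation, the L-BFGS line search, and the initial L-BFGS evaluation further down. A self-contained sketch of that shared pattern in plain C++ (illustrative only; `compute_micro_batch` is a made-up stand-in for evaluating the backward graph):

```cpp
// Self-contained illustration of the accumulate-then-average pattern used in
// this commit (plain C++, no ggml calls; all names here are illustrative).
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

float eval_with_accumulation(
        int n_accum,
        const std::function<float(std::vector<float> &)> & compute_micro_batch, // returns loss, fills grads
        std::vector<float> & g) {                                               // accumulated gradient (output)
    const float accum_norm = 1.0f / (float) n_accum;
    std::fill(g.begin(), g.end(), 0.0f);          // like ggml_set_zero(opt->adam.g)
    float fx = 0.0f;
    for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
        std::vector<float> step_grad(g.size(), 0.0f);
        fx += compute_micro_batch(step_grad);     // forward+backward for one micro-batch
        for (std::size_t i = 0; i < g.size(); ++i) {
            g[i] += step_grad[i] * accum_norm;    // like ggml_opt_acc_grad(np, ps, g, accum_norm)
        }
    }
    return fx * accum_norm;                       // averaged loss, like fx *= accum_norm
}

int main() {
    std::vector<float> g(2);
    // toy micro-batch: constant gradient (1, 2) and loss 0.5
    float fx = eval_with_accumulation(4, [](std::vector<float> & gr) {
        gr[0] = 1.0f; gr[1] = 2.0f; return 0.5f;
    }, g);
    printf("fx=%f g=(%f, %f)\n", fx, g[0], g[1]);
    return 0;
}
```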
@@ -19234,12 +19256,8 @@ static enum ggml_opt_result ggml_opt_adam(
         if (gclip > 0.0f) {
             // gradient clipping
             ggml_float sum = 0.0;
-            for (int p = 0; p < np; ++p) {
-                const int64_t ne = ggml_nelements(ps[p]);
-                for (int64_t j = 0; j < ne; ++j) {
-                    float g = ggml_get_f32_1d(ps[p]->grad, j);
-                    sum += (ggml_float)(g*g);
-                }
+            for (int64_t i = 0; i < nx; ++i) {
+                sum += (ggml_float)(g[i]*g[i]);
             }
             ggml_float norm = sqrt(sum);
             if (norm > (ggml_float) gclip) {
@@ -19253,10 +19271,10 @@ static enum ggml_opt_result ggml_opt_adam(
             const int64_t ne = ggml_nelements(ps[p]);
             const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
             for (int64_t j = 0; j < ne; ++j) {
                 float x = ggml_get_f32_1d(ps[p], j);
-                float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
-                m[i] = m[i]*beta1 + g*(1.0f - beta1);
-                v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
+                float g_ = g[i]*gnorm;
+                m[i] = m[i]*beta1 + g_*(1.0f - beta1);
+                v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
                 float mh = m[i]*beta1h;
                 float vh = v[i]*beta2h;
                 vh = sqrtf(vh) + eps;
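Side note on the `g_` rename (my reading of the diff): the flat gradient buffer introduced above is named `g`, so the former local `float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;` is renamed to `g_` to avoid shadowing it, and the update now reads the already-accumulated, averaged value `g[i]*gnorm` instead of pulling the gradient from the parameter's grad tensor directly.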
@@ -19267,16 +19285,20 @@ static enum ggml_opt_result ggml_opt_adam(
             }
         }

-        if (callback) {
-            callback(callback_data, &sched);
+        fx = 0;
+        ggml_set_zero(opt->adam.g);
+        for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+            if (callback) {
+                callback(callback_data, accum_step, &sched);
+            }
+            // ggml_graph_reset (gf);
+            ggml_set_f32 (f->grad, 1.0f);
+            ggml_graph_compute(gb, &cplan);
+            ggml_opt_acc_grad(np, ps, g, accum_norm);
+            fx += ggml_get_f32_1d(f, 0);
         }
+        fx *= accum_norm;

-        // ggml_graph_reset (gf);
-        ggml_set_f32 (f->grad, 1.0f);
-
-        ggml_graph_compute(gb, &cplan);
-
-        const float fx = ggml_get_f32_1d(f, 0);
         opt->loss_after = fx;

@@ -19373,6 +19395,9 @@ static enum ggml_opt_result linesearch_backtracking(
     const float dec = 0.5f;
     const float inc = 2.1f;

+    const int n_accum = MAX(1, params->n_gradient_accumulation);
+    const float accum_norm = 1.0f / (float) n_accum;
+
     if (*step <= 0.f) {
         return GGML_LINESEARCH_INVALID_PARAMETERS;
     }
@@ -19390,12 +19415,6 @@ static enum ggml_opt_result linesearch_backtracking(
     dgtest = params->lbfgs.ftol*dginit;

     while (true) {
-        if (callback) {
-            // LBFG-S does not support learning rate -> ignore learning schedule
-            float sched = 0;
-            callback(callback_data, &sched);
-        }
-
         ggml_vec_cpy_f32(nx, x, xp);
         ggml_vec_mad_f32(nx, x, d, *step);

@@ -19403,14 +19422,22 @@ static enum ggml_opt_result linesearch_backtracking(
         {
             ggml_opt_set_params(np, ps, x);

-            //ggml_graph_reset (gf);
-            ggml_set_f32 (f->grad, 1.0f);
-
-            ggml_graph_compute(gb, cplan);
-
-            ggml_opt_get_grad(np, ps, g);
-
-            *fx = ggml_get_f32_1d(f, 0);
+            *fx = 0;
+            memset(g, 0, sizeof(float)*nx);
+            for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+                if (callback) {
+                    // LBFG-S does not support learning rate -> ignore learning schedule
+                    float sched = 0;
+                    callback(callback_data, accum_step, &sched);
+                }
+                // ggml_graph_reset (gf);
+                ggml_set_f32 (f->grad, 1.0f);
+                ggml_graph_compute(gb, cplan);
+                ggml_opt_acc_grad(np, ps, g, accum_norm);
+                *fx += ggml_get_f32_1d(f, 0);
+            }
+            *fx *= accum_norm;
+
         }

         ++count;
@@ -19512,6 +19539,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(

     float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values

+    const int n_accum = MAX(1, params.n_gradient_accumulation);
+    const float accum_norm = 1.0f / (float) n_accum;
+
     float fx = 0.0f; // cost function value
     float xnorm = 0.0f; // ||x||
     float gnorm = 0.0f; // ||g||
@@ -19525,24 +19555,25 @@ static enum ggml_opt_result ggml_opt_lbfgs(
     float * lm_s = opt->lbfgs.lms->data;
     float * lm_y = opt->lbfgs.lmy->data;

-    if (callback) {
-        // LBFG-S does not support learning rate -> ignore learning schedule
-        float sched = 0;
-        callback(callback_data, &sched);
-    }
-
     // evaluate the function value and its gradient
     {
         ggml_opt_set_params(np, ps, x);

-        //ggml_graph_reset (gf);
-        ggml_set_f32 (f->grad, 1.0f);
-
-        ggml_graph_compute(gb, &cplan);
-
-        ggml_opt_get_grad(np, ps, g);
-
-        fx = ggml_get_f32_1d(f, 0);
-
+        fx = 0;
+        memset(g, 0, sizeof(float)*nx);
+        for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+            if (callback) {
+                // LBFG-S does not support learning rate -> ignore learning schedule
+                float sched = 0;
+                callback(callback_data, accum_step, &sched);
+            }
+            // ggml_graph_reset (gf);
+            ggml_set_f32 (f->grad, 1.0f);
+            ggml_graph_compute(gb, &cplan);
+            ggml_opt_acc_grad(np, ps, g, accum_norm);
+            fx += ggml_get_f32_1d(f, 0);
+        }
+        fx *= accum_norm;
+
         opt->loss_before = fx;
         opt->loss_after = fx;
@@ -19729,6 +19760,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
                 .print_forward_graph = true,
                 .print_backward_graph = true,

+                .n_gradient_accumulation = 1,
+
                 .adam = {
                     .n_iter = 10000,
                     .sched = 1.000f,
@@ -19757,6 +19790,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
                 .print_forward_graph = true,
                 .print_backward_graph = true,

+                .n_gradient_accumulation = 1,
+
                 .lbfgs = {
                     .m = 6,
                     .n_iter = 100,
@@ -19790,7 +19825,7 @@ GGML_API void ggml_opt_init(
     if (opt->ctx == NULL) {
         struct ggml_init_params ctx_opt_params;
         if (opt->params.type == GGML_OPT_ADAM) {
-            ctx_opt_params.mem_size = GGML_MEM_ALIGN*2 + ggml_tensor_overhead()*2 + ggml_type_size(GGML_TYPE_F32)*nx*2;
+            ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3;
             if (opt->params.past > 0) {
                 ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
             }
@@ -19808,6 +19843,7 @@ GGML_API void ggml_opt_init(
     switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
+                opt->adam.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
                 opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
                 opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
                 opt->adam.pf = params.past > 0
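The mem_size bump from *2 to *3 matches the new `opt->adam.g` tensor: Adam now keeps three F32 buffers of `nx` elements (accumulated gradient, first moment, second moment). A rough back-of-the-envelope with a made-up parameter count (my arithmetic, not from the commit):

```cpp
// Rough size of the Adam optimizer state after this change (illustrative only).
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t nx = 110'000'000;     // e.g. ~110M trainable parameters (made-up number)
    const int64_t n_f32_buffers = 3;    // g (new), m, v
    const double bytes = (double) nx * sizeof(float) * n_f32_buffers;
    printf("~%.1f MiB of F32 optimizer state (+ tensor overhead and alignment)\n",
           bytes / (1024.0 * 1024.0));
    return 0;
}
```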
ggml.h (5 lines changed)
@@ -1708,7 +1708,7 @@ extern "C" {
        GGML_LINESEARCH_INVALID_PARAMETERS,
    };

-    typedef void (*ggml_opt_callback)(void * data, float * sched);
+    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched);

    // optimization parameters
    //
@@ -1739,6 +1739,8 @@ extern "C" {
        bool print_forward_graph;
        bool print_backward_graph;

+       int n_gradient_accumulation;
+
        // ADAM parameters
        struct {
            int n_iter;
@@ -1784,6 +1786,7 @@ extern "C" {
        float loss_after;

        struct {
+            struct ggml_tensor * g;  // current gradient
            struct ggml_tensor * m;  // first moment
            struct ggml_tensor * v;  // second moment
            struct ggml_tensor * pf; // past function values
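With the widened `ggml_opt_callback` typedef, a caller-supplied callback now receives the accumulation step index as well. A hedged sketch of how such a callback might use it (standalone toy with no ggml dependency; `my_state` and the schedule value are made up, only the signature shape comes from the diff):

```cpp
// Sketch of a user callback matching the new signature
// typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched);
#include <cstdio>

struct my_state {
    int   iter_logged;
    float base_lr_sched;
};

// same parameter shape as the updated ggml_opt_callback
static void my_opt_callback(void * data, int accum_step, float * sched) {
    my_state * st = (my_state *) data;
    // per-iteration work (logging, checkpointing, LR schedule) belongs on the
    // first accumulation step only, mirroring the `accum_step == 0` guard
    // added to opt_callback in this commit
    if (accum_step == 0) {
        *sched = st->base_lr_sched;
        printf("starting optimizer iteration (logged %d so far)\n", st->iter_logged++);
    }
    // for accum_step > 0 there is nothing iteration-level left to do here
}

int main() {
    my_state st = { 0, 1.0f };
    float sched = 0.0f;
    for (int accum_step = 0; accum_step < 4; ++accum_step) {
        my_opt_callback(&st, accum_step, &sched);  // how the optimizer would invoke it per step
    }
    return 0;
}
```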