print time per iteration and estimate remaining time

This commit is contained in:
xaedes 2023-09-01 17:03:36 +02:00
parent 6809eb7de9
commit c32ad44f84
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -2221,8 +2221,34 @@ struct opt_callback_data {
int shuffle_countdown;
struct ggml_tensor * tokens_input;
struct ggml_tensor * target_probs;
int first_iter;
int64_t last_time;
float time_per_iter;
};
// Print a duration, given in milliseconds, to stdout (no trailing newline).
//
// Durations below one second are printed as fractional milliseconds, e.g.
// "12.3ms". Longer durations are truncated to whole seconds and printed as
// "HH:MM:SS", prefixed with "<days>d " when they span at least one full day.
void print_duration(float fmillis) {
    // Written as a negated >= so that NaN also takes this branch instead of
    // reaching the float -> integer conversion below, which is undefined for NaN.
    if (!(fmillis >= 1000.0f)) {
        printf("%.1fms", fmillis);
        return;
    }

    const int64_t one_sec  = 1000;
    const int64_t one_min  = one_sec  * 60;
    const int64_t one_hour = one_min  * 60;
    const int64_t one_day  = one_hour * 24;

    // Peel off the largest unit first; `rest` always holds what is still unaccounted for.
    int64_t rest = static_cast<int64_t>(fmillis);
    const int64_t days    = rest / one_day;
    rest -= days * one_day;
    const int64_t hours   = rest / one_hour;
    rest -= hours * one_hour;
    const int64_t minutes = rest / one_min;
    rest -= minutes * one_min;
    const int64_t seconds = rest / one_sec;

    // int64_t is not `long` on every platform (e.g. 64-bit Windows), so "%ld" would be a
    // format/argument mismatch there. Cast to long long and use "%lld", which is portable
    // and needs no extra header.
    if (days > 0) {
        printf("%lldd %02lld:%02lld:%02lld",
               static_cast<long long>(days),
               static_cast<long long>(hours),
               static_cast<long long>(minutes),
               static_cast<long long>(seconds));
    } else {
        printf("%02lld:%02lld:%02lld",
               static_cast<long long>(hours),
               static_cast<long long>(minutes),
               static_cast<long long>(seconds));
    }
}
void opt_callback(void * vdata, float * sched) {
struct opt_callback_data * data = (struct opt_callback_data *) vdata;
struct train_params * params = data->params;
@ -2230,6 +2256,25 @@ void opt_callback(void * vdata, float * sched) {
int n_batch = params->n_batch;
int n_ctx = params->n_ctx;
int64_t now = ggml_time_ms();
if (now > data->last_time) {
float dt = now - data->last_time;
if (data->time_per_iter == 0) {
data->time_per_iter = dt;
} else {
const float gain = 0.7f;
data->time_per_iter = data->time_per_iter*(1.0f-gain) + dt*gain;
}
}
data->last_time = now;
float remaining_time = 0;
if (data->time_per_iter > 0) {
const int n_iter = params->use_adam ? params->adam_n_iter : params->lbfgs_n_iter;
const int done_iter = opt->iter - data->first_iter;
const int remaining_iter = n_iter - done_iter;
remaining_time = remaining_iter * data->time_per_iter;
}
const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
if (save_now) {
int new_iters = opt->iter - data->last_save_iter;
@ -2262,7 +2307,15 @@ void opt_callback(void * vdata, float * sched) {
int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
if (impr_plot > 0) impr_plot = 0;
if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0;
printf("%s: iter=%*d, sched=%f loss=%f ", __func__, 6, opt->iter, *sched, opt->loss_after);
printf("%s: iter=%*d sched=%f loss=%f",
__func__, 6, opt->iter, *sched, opt->loss_after);
if (data->time_per_iter > 0) {
printf(" dt=");
print_duration(data->time_per_iter);
printf(" eta=");
print_duration(remaining_time);
}
float improvement = opt->loss_before - opt->loss_after;
const float plot_scale = 10.0f;
int bar_len = (int)(1 + improvement*plot_scale + 0.5);
@ -2618,6 +2671,9 @@ int main(int argc, char ** argv) {
opt_cb_data.shuffle_countdown = train_samples.size();
opt_cb_data.tokens_input = tokens_input;
opt_cb_data.target_probs = target_probs;
opt_cb_data.first_iter = opt->iter;
opt_cb_data.last_time = ggml_time_ms();
opt_cb_data.time_per_iter = 0;
// measure required memory for work buffer
size_t max_work_size = ggml_graph_plan(gb, params.n_threads).work_size + GGML_OBJECT_SIZE;