move common opt_callback into common/train

This commit is contained in:
xaedes 2023-09-16 18:51:16 +02:00
parent e9758ae1d2
commit bef1e97875
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
4 changed files with 146 additions and 287 deletions

View file

@ -1329,3 +1329,118 @@ void finish_processing_train_args(struct train_params_common * params) {
process_escapes(params->sample_start);
}
}
void train_opt_callback(void * vdata, int accum_step, float * sched) {
struct train_opt_callback_data * data = (struct train_opt_callback_data *) vdata;
struct train_params_common * params = data->params;
struct train_state * train = data->train;
struct ggml_opt_context * opt = train->opt;
int n_batch = params->n_batch;
int n_ctx = params->n_ctx;
if (accum_step == 0) {
// time measurement
int64_t now = ggml_time_ms();
if (now > data->last_time && opt->iter > data->first_iter) {
double dt = (double) (now - data->last_time);
if (data->millis_per_iter == 0.0) {
data->millis_per_iter = dt;
} else {
const double gain = 0.7;
data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain;
}
}
double remaining_millis = 0.0;
if (data->millis_per_iter > 0.0) {
const int n_iter = params->adam_n_iter;
const int done_iter = opt->iter - data->first_iter;
const int remaining_iter = n_iter - done_iter;
remaining_millis = remaining_iter * data->millis_per_iter;
}
// file saving
const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
if (save_now) {
int new_iters = opt->iter - data->last_save_iter;
train->train_its += new_iters;
train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
if (data->save_cb) {
data->save_cb(data->save_data, train);
}
data->last_save_iter = opt->iter;
}
// exclude file saving from time measurement, by measuring last_time after saving
data->last_time = ggml_time_ms();
*sched = learning_schedule(
opt->iter,
params->warmup,
params->cos_decay_steps,
params->adam_alpha,
params->adam_min_alpha,
params->cos_decay_min,
params->cos_decay_restart,
params->enable_restart);
int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
if (impr_plot > 0) impr_plot = 0;
if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0;
printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
__func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
*sched, opt->loss_after);
if (data->millis_per_iter > 0) {
printf(" dt=");
print_duration(data->millis_per_iter);
printf(" eta=");
print_duration(remaining_millis);
}
float improvement = opt->loss_before - opt->loss_after;
const float plot_scale = 10.0f;
int bar_len = (int)(1 + improvement*plot_scale + 0.5);
printf(" |");
for (int i=0; i<bar_len; ++i) {
printf("-");
}
printf(">");
printf("\n");
}
int64_t used_samples = get_example_targets_batch(
data->lctx,
data->tokens_input,
data->target_probs,
train->shuffle_next_sample,
data->shuffled_samples_begin,
data->shuffled_samples_size,
data->samples_count,
data->tokens_data,
data->tokens_size,
params->separate_with_eos,
params->separate_with_bos,
params->fill_with_next_samples);
train->shuffle_next_sample += used_samples;
if (train->shuffle_next_sample >= train->shuffle_sample_count) {
++train->train_epochs;
printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs);
// note: we may have used some samples from the current shuffling more than once
train->shuffle_rng_state_current = train->shuffle_rng_state_next;
train->shuffle_rng_state_next = shuffle_samples(
train->shuffle_rng_state_current,
data->shuffled_samples_begin,
data->shuffled_samples_size,
data->samples_begin,
data->samples_size,
data->samples_count);
train->shuffle_next_sample = 0;
}
}

View file

@ -80,6 +80,29 @@ struct train_params_common {
float adam_eps_f;
};
typedef void (*save_train_files_callback)(void * data, struct train_state * train);
// State handed to train_opt_callback through its opaque `vdata` pointer.
struct train_opt_callback_data {
    // training configuration and mutable training state
    struct train_params_common * params;
    struct train_state         * train;

    // optional checkpoint hook, called as save_cb(save_data, train)
    save_train_files_callback save_cb;
    void                    * save_data;

    struct llama_context * lctx;

    // value of opt->iter at the most recent save
    int last_save_iter;

    // token buffer and its length, forwarded to get_example_targets_batch
    llama_token * tokens_data;
    size_t        tokens_size;

    // sample tables: the unshuffled pair is the input of shuffle_samples,
    // the shuffled pair is its output and is what batches are drawn from
    size_t * samples_begin;
    size_t * samples_size;
    size_t * shuffled_samples_begin;
    size_t * shuffled_samples_size;
    size_t   samples_count;

    // tensors forwarded to get_example_targets_batch for each batch
    struct ggml_tensor * tokens_input;
    struct ggml_tensor * target_probs;

    // progress / timing bookkeeping
    int     first_iter;       // iteration that elapsed-iteration counts are measured from
    int64_t last_time;        // ggml_time_ms() timestamp taken after the last save check
    double  millis_per_iter;  // smoothed iteration time; 0.0 until the first measurement
};
struct train_state * init_train_state(int seed);
void free_train_state(struct train_state * state);
@ -195,4 +218,4 @@ void save_train_state_gguf(struct gguf_context * fctx, struct train_state * trai
std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration);
typedef void (*save_train_files_callback)(void * data, struct train_state * train);
void train_opt_callback(void * vdata, int accum_step, float * sched);

View file

@ -1512,142 +1512,6 @@ static void save_train_files(void * vdata, struct train_state * train) {
}
}
// State handed to opt_callback through its opaque `vdata` pointer.
struct opt_callback_data {
    // training configuration and mutable training state
    struct train_params_common * params;
    struct train_state         * train;

    // optional checkpoint hook, called as save_cb(save_data, train)
    save_train_files_callback save_cb;
    void                    * save_data;

    struct llama_context * lctx;

    // value of opt->iter at the most recent save
    int last_save_iter;

    // token buffer and its length, forwarded to get_example_targets_batch
    llama_token * tokens_data;
    size_t        tokens_size;

    // sample tables: the unshuffled pair is the input of shuffle_samples,
    // the shuffled pair is its output and is what batches are drawn from
    size_t * samples_begin;
    size_t * samples_size;
    size_t * shuffled_samples_begin;
    size_t * shuffled_samples_size;
    size_t   samples_count;

    // tensors forwarded to get_example_targets_batch for each batch
    struct ggml_tensor * tokens_input;
    struct ggml_tensor * target_probs;

    // progress / timing bookkeeping
    int     first_iter;       // iteration that elapsed-iteration counts are measured from
    int64_t last_time;        // ggml_time_ms() timestamp taken after the last save check
    double  millis_per_iter;  // smoothed iteration time; 0.0 until the first measurement
};
static void opt_callback(void * vdata, int accum_step, float * sched) {
struct opt_callback_data * data = (struct opt_callback_data *) vdata;
struct train_params_common * params = data->params;
struct train_state * train = data->train;
struct ggml_opt_context * opt = train->opt;
int n_batch = params->n_batch;
int n_ctx = params->n_ctx;
if (accum_step == 0) {
// time measurement
int64_t now = ggml_time_ms();
if (now > data->last_time && opt->iter > data->first_iter) {
double dt = now - data->last_time;
if (data->millis_per_iter == 0.0) {
data->millis_per_iter = dt;
} else {
const double gain = 0.7;
data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain;
}
}
double remaining_millis = 0.0;
if (data->millis_per_iter > 0.0) {
const int n_iter = params->adam_n_iter;
const int done_iter = opt->iter - data->first_iter;
const int remaining_iter = n_iter - done_iter;
remaining_millis = remaining_iter * data->millis_per_iter;
}
// file saving
const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
if (save_now) {
int new_iters = opt->iter - data->last_save_iter;
train->train_its += new_iters;
train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
if (data->save_cb) {
data->save_cb(data->save_data, train);
}
data->last_save_iter = opt->iter;
}
// exclude file saving from time measurement, by measuring last_time after saving
data->last_time = ggml_time_ms();
*sched = learning_schedule(
opt->iter,
params->warmup,
params->cos_decay_steps,
params->adam_alpha,
params->adam_min_alpha,
params->cos_decay_min,
params->cos_decay_restart,
params->enable_restart);
int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
if (impr_plot > 0) impr_plot = 0;
if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0;
printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
__func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
*sched, opt->loss_after);
if (data->millis_per_iter > 0) {
printf(" dt=");
print_duration(data->millis_per_iter);
printf(" eta=");
print_duration(remaining_millis);
}
float improvement = opt->loss_before - opt->loss_after;
const float plot_scale = 10.0f;
int bar_len = (int)(1 + improvement*plot_scale + 0.5);
printf(" |");
for (int i=0; i<bar_len; ++i) {
printf("-");
}
printf(">");
printf("\n");
}
int64_t used_samples = get_example_targets_batch(
data->lctx,
data->tokens_input,
data->target_probs,
train->shuffle_next_sample,
data->shuffled_samples_begin,
data->shuffled_samples_size,
data->samples_count,
data->tokens_data,
data->tokens_size,
params->separate_with_eos,
params->separate_with_bos,
params->fill_with_next_samples);
train->shuffle_next_sample += used_samples;
if (train->shuffle_next_sample >= train->shuffle_sample_count) {
++train->train_epochs;
printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs);
// note: we may have used some samples from the current shuffling more than once
train->shuffle_rng_state_current = train->shuffle_rng_state_next;
train->shuffle_rng_state_next = shuffle_samples(
train->shuffle_rng_state_current,
data->shuffled_samples_begin,
data->shuffled_samples_size,
data->samples_begin,
data->samples_size,
data->samples_count);
train->shuffle_next_sample = 0;
}
}
static int64_t get_parameter_count(struct my_llama_lora* lora) {
int64_t nx = 0;
nx += ggml_nelements(lora->tok_embeddings_a);
@ -2023,7 +1887,7 @@ int main(int argc, char ** argv) {
save_data.model = &model;
save_data.lora = &lora;
struct opt_callback_data opt_cb_data;
struct train_opt_callback_data opt_cb_data;
opt_cb_data.params = &params.common;
opt_cb_data.train = train;
opt_cb_data.save_cb = &save_train_files;
@ -2057,7 +1921,7 @@ int main(int argc, char ** argv) {
int64_t t0 = ggml_time_ms();
ggml_opt_resume_g(ctx_work, opt, loss, gf, gb, &opt_callback, (void *) &opt_cb_data);
ggml_opt_resume_g(ctx_work, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data);
ggml_free(ctx_work);
ggml_free(ctx_compute);

View file

@ -919,144 +919,6 @@ static void save_train_files(void * vdata, struct train_state * train) {
}
}
// State handed to opt_callback through its opaque `vdata` pointer.
struct opt_callback_data {
    // training configuration and mutable training state
    struct train_params_common * params;
    struct train_state         * train;

    // optional checkpoint hook, called as save_cb(save_data, train)
    save_train_files_callback save_cb;
    void                    * save_data;

    struct llama_context * lctx;

    // value of opt->iter at the most recent save
    int last_save_iter;

    // token buffer and its length, forwarded to get_example_targets_batch
    llama_token * tokens_data;
    size_t        tokens_size;

    // sample tables: the unshuffled pair is the input of shuffle_samples,
    // the shuffled pair is its output and is what batches are drawn from
    size_t * samples_begin;
    size_t * samples_size;
    size_t * shuffled_samples_begin;
    size_t * shuffled_samples_size;
    size_t   samples_count;

    // tensors; tokens_input and target_probs are forwarded to
    // get_example_targets_batch, target_logits is not read by opt_callback
    struct ggml_tensor * tokens_input;
    struct ggml_tensor * target_logits;
    struct ggml_tensor * target_probs;

    // progress / timing bookkeeping
    int     first_iter;       // iteration that elapsed-iteration counts are measured from
    int64_t last_time;        // ggml_time_ms() timestamp taken after the last save check
    double  millis_per_iter;  // smoothed iteration time; 0.0 until the first measurement
};
static void opt_callback(void * vdata, int accum_step, float * sched) {
struct opt_callback_data * data = (struct opt_callback_data *) vdata;
struct train_params_common * params = data->params;
struct train_state * train = data->train;
struct ggml_opt_context * opt = train->opt;
int n_batch = params->n_batch;
int n_ctx = params->n_ctx;
if (accum_step == 0) {
// time measurement
int64_t now = ggml_time_ms();
if (now > data->last_time && opt->iter > data->first_iter) {
double dt = now - data->last_time;
if (data->millis_per_iter == 0.0) {
data->millis_per_iter = dt;
} else {
const double gain = 0.7;
data->millis_per_iter = data->millis_per_iter*(1.0-gain) + dt*gain;
}
}
double remaining_millis = 0.0;
if (data->millis_per_iter > 0.0) {
const int n_iter = params->adam_n_iter;
const int done_iter = opt->iter - data->first_iter;
const int remaining_iter = n_iter - done_iter;
remaining_millis = remaining_iter * data->millis_per_iter;
}
// file saving
const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
if (save_now) {
int new_iters = opt->iter - data->last_save_iter;
train->train_its += new_iters;
train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
if (data->save_cb) {
data->save_cb(data->save_data, train);
}
data->last_save_iter = opt->iter;
}
// exclude file saving from time measurement, by measuring last_time after saving
data->last_time = ggml_time_ms();
*sched = learning_schedule(
opt->iter,
params->warmup,
params->cos_decay_steps,
params->adam_alpha,
params->adam_min_alpha,
params->cos_decay_min,
params->cos_decay_restart,
params->enable_restart);
int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
if (impr_plot > 0) impr_plot = 0;
if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0;
printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
__func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
*sched, opt->loss_after);
if (data->millis_per_iter > 0) {
printf(" dt=");
print_duration(data->millis_per_iter);
printf(" eta=");
print_duration(remaining_millis);
}
float improvement = opt->loss_before - opt->loss_after;
const float plot_scale = 10.0f;
int bar_len = (int)(1 + improvement*plot_scale + 0.5);
printf(" |");
for (int i=0; i<bar_len; ++i) {
printf("-");
}
printf(">");
printf("\n");
}
int64_t used_samples = get_example_targets_batch(
data->lctx,
data->tokens_input,
data->target_probs,
train->shuffle_next_sample,
data->shuffled_samples_begin,
data->shuffled_samples_size,
data->samples_count,
data->tokens_data,
data->tokens_size,
params->separate_with_eos,
params->separate_with_bos,
params->fill_with_next_samples);
train->shuffle_next_sample += used_samples;
if (train->shuffle_next_sample >= train->shuffle_sample_count) {
++train->train_epochs;
printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs);
// note: we may have used some samples from the current shuffling more than once
train->shuffle_rng_state_current = train->shuffle_rng_state_next;
train->shuffle_rng_state_next = shuffle_samples(
train->shuffle_rng_state_current,
data->shuffled_samples_begin,
data->shuffled_samples_size,
data->samples_begin,
data->samples_size,
data->samples_count);
train->shuffle_next_sample = 0;
}
}
int main(int argc, char ** argv) {
struct train_params params = get_default_train_params();
@ -1211,7 +1073,7 @@ int main(int argc, char ** argv) {
save_data.fn_latest = params.common.fn_latest;
save_data.model = &model;
struct opt_callback_data opt_cb_data;
struct train_opt_callback_data opt_cb_data;
opt_cb_data.params = &params.common;
opt_cb_data.train = train;
opt_cb_data.save_cb = &save_train_files;
@ -1226,7 +1088,6 @@ int main(int argc, char ** argv) {
opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data();
opt_cb_data.samples_count = train_samples_size.size();
opt_cb_data.tokens_input = NULL;
opt_cb_data.target_logits = NULL;
opt_cb_data.target_probs = NULL;
opt_cb_data.first_iter = opt->iter;
opt_cb_data.last_time = ggml_time_ms();
@ -1246,10 +1107,7 @@ int main(int argc, char ** argv) {
ggml_set_no_alloc(ctx0, false);
// don't use alloc for input tensors, so we can safely fill them with data
//struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
//struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
ggml_set_no_alloc(ctx0, (alloc != NULL));
@ -1259,7 +1117,6 @@ int main(int argc, char ** argv) {
}
opt_cb_data.tokens_input = tokens_input;
opt_cb_data.target_logits = target_logits;
opt_cb_data.target_probs = target_probs;
int n_past = 0;
@ -1298,7 +1155,7 @@ int main(int argc, char ** argv) {
printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
ggml_opt_resume_g(ctx0, opt, loss, gf, gb, &opt_callback, (void *) &opt_cb_data);
ggml_opt_resume_g(ctx0, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data);
size_t used_mem_after_opt = ggml_used_mem(ctx0);