diff --git a/common/train.cpp b/common/train.cpp index a1e35e5a3..c2b3f036b 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -17,6 +17,28 @@ struct random_uniform_distribution { std::uniform_real_distribution<float> rd; }; +struct train_state * init_train_state(int seed) { + struct train_state * state = (struct train_state *) malloc(sizeof(struct train_state)); + memset(state, 0, sizeof(struct train_state)); + state->shuffle_rng_state_current = ""; + state->shuffle_rng_state_next = ""; + + state->opt = (struct ggml_opt_context *) malloc(sizeof(struct ggml_opt_context)); + memset(state->opt, 0, sizeof(struct ggml_opt_context)); + state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM); + + return state; +} + +void free_train_state(struct train_state * state) { + free(state->opt); + free(state); +} + +struct ggml_opt_context * get_train_state_opt(struct train_state * state) { + return state->opt; +} + struct random_normal_distribution * init_random_normal_distribution(int seed, float mean, float std, float min, float max) { struct random_normal_distribution * rnd = (struct random_normal_distribution *) malloc(sizeof(struct random_normal_distribution)); rnd->gen = std::mt19937(seed); @@ -472,6 +494,20 @@ static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer. static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s"; static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y"; +static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model"; +static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora"; +static const char * LLM_KV_TRAINING_TYPE = "training.type"; +static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version"; +static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"; +static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"; +static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"; +static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count"; +static const char * LLM_KV_TRAINING_SAMPLES_HASH = "training.samples_hash"; +static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash"; +static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state"; +static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count"; +static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE = "training.shuffle.next_sample"; + #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \ { \ const std::string skey(key); \ @@ -613,6 +649,59 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * } } +bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train) { + if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) < 0) { + return false; + } + + uint32_t file_version; + GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION); + GGML_ASSERT(file_version <= 1); + + std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA; + GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE); + GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL || train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA); + + if (file_version == 0) { + + GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT); + GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true,
LLM_KV_TRAINING_SAMPLE_COUNT); + GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT); + + } else if (file_version == 1) { + + GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT); + GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT); + GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT); + GGUF_GET_KEY(fctx, train->train_epochs, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT); + + GGUF_GET_KEY(fctx, train->shuffle_samples_hash, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH); + GGUF_GET_KEY(fctx, train->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE); + GGUF_GET_KEY(fctx, train->shuffle_sample_count, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT); + GGUF_GET_KEY(fctx, train->shuffle_next_sample, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE); + } + + load_opt_context_gguf(fctx, f_ggml_ctx, train->opt); + return true; +} + +void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) { + gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1); + gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, train->train_samples); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, train->train_tokens); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT, train->train_epochs); + + gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) train->shuffle_samples_hash); + gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE, train->shuffle_rng_state_current.c_str()); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) train->shuffle_sample_count); + gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE, (uint64_t) train->shuffle_next_sample); + + save_opt_context_gguf(fctx, train->opt); +} + + struct llama_file { // use FILE * so we don't have to re-open the file to mmap FILE * fp; diff --git a/common/train.h b/common/train.h index 9d629beb7..54edd0f4a 100644 --- a/common/train.h +++ b/common/train.h @@ -9,6 +9,26 @@ #include "ggml.h" #include "llama.h" +typedef std::string mt19937_state; + +struct train_state { + struct ggml_opt_context * opt; + + uint64_t train_its; + uint64_t train_samples; + uint64_t train_tokens; + uint64_t train_epochs; + + size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes) + mt19937_state shuffle_rng_state_current; + mt19937_state shuffle_rng_state_next; + size_t shuffle_sample_count; + size_t shuffle_next_sample; +}; + +struct train_state * init_train_state(int seed); +void free_train_state(struct train_state * state); + struct random_normal_distribution; struct random_uniform_distribution; @@ -58,7 +78,6 @@ int64_t get_example_targets_batch( bool separate_with_bos, bool fill_with_next_samples); -typedef std::string mt19937_state; void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state); mt19937_state mt19937_get_state(const std::mt19937& rng); @@ -111,3 +130,6 @@ void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, co void load_opt_context_gguf(struct gguf_context * 
fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt); void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt); +bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train); +void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train); + diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index ce6f28bad..58e96f186 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -148,34 +148,9 @@ struct my_llama_lora { struct ggml_tensor * output_b; std::vector<my_llama_lora_layer> layers; - - uint64_t train_its = 0; - uint64_t train_samples = 0; - uint64_t train_tokens = 0; - uint64_t train_epochs = 0; - - size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes) - std::string shuffle_rng_state_current; - std::string shuffle_rng_state_next; - size_t shuffle_sample_count; - size_t shuffle_next_sample; }; // gguf constants -static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model"; -static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora"; -static const char * LLM_KV_TRAINING_TYPE = "training.type"; -static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version"; -static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"; -static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"; -static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"; -static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count"; -static const char * LLM_KV_TRAINING_SAMPLES_HASH = "training.samples_hash"; -static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash"; -static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state"; -static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count"; -static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE = "training.shuffle.next_sample"; - static const char * LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD = "training.lora.rank.token_embd"; static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM = "training.lora.rank.output_norm"; static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT = "training.lora.rank.output"; @@ -336,10 +311,6 @@ static void init_lora(const struct my_llama_model * model, struct my_llama_lora const uint32_t n_vocab = model->hparams.n_vocab; const uint32_t n_ff = model->hparams.n_ff; - lora->train_its = 0; - lora->train_samples = 0; - lora->train_tokens = 0; - std::vector<char> tn_buf; tn_buf.resize(GGML_MAX_NAME); auto tn = [&tn_buf](const char * key, const char * suffix) -> const char * { @@ -869,8 +840,6 @@ static void load_default_lora_params_from_base_model(const char * fn_base_model, gguf_free(fctx); } - - static void load_llama_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora) { // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read @@ -1021,58 +990,17 @@ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_mod } } -static void load_checkpoint_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora, struct ggml_opt_context * opt) { +static void load_checkpoint_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct
my_llama_lora * lora, struct train_state * train) { load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora); - - uint32_t file_version; - GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION); - GGML_ASSERT(file_version <= 1); - - std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA; - GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE); - GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA); - - if (file_version == 0) { - - GGUF_GET_KEY(fctx, lora->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT); - GGUF_GET_KEY(fctx, lora->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT); - GGUF_GET_KEY(fctx, lora->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT); - - } else if (file_version == 1) { - - GGUF_GET_KEY(fctx, lora->train_its, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT); - GGUF_GET_KEY(fctx, lora->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT); - GGUF_GET_KEY(fctx, lora->train_tokens, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT); - GGUF_GET_KEY(fctx, lora->train_epochs, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT); - - GGUF_GET_KEY(fctx, lora->shuffle_samples_hash, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH); - GGUF_GET_KEY(fctx, lora->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE); - GGUF_GET_KEY(fctx, lora->shuffle_sample_count, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT); - GGUF_GET_KEY(fctx, lora->shuffle_next_sample, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE); - } - - load_opt_context_gguf(fctx, f_ggml_ctx, opt); + load_train_state_gguf(fctx, f_ggml_ctx, train); } -static void save_checkpoint_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora, struct ggml_opt_context * opt) { +static void save_checkpoint_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) { save_llama_lora_gguf(fctx, model, lora); - - gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1); - gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, lora->train_its); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, lora->train_samples); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, lora->train_tokens); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT, lora->train_epochs); - - gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) lora->shuffle_samples_hash); - gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE, lora->shuffle_rng_state_current.c_str()); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) lora->shuffle_sample_count); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE, (uint64_t) lora->shuffle_next_sample); - - save_opt_context_gguf(fctx, opt); + save_train_state_gguf(fctx, train); } -static bool load_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct ggml_opt_context * opt) { +static bool load_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct 
my_llama_lora * lora, struct train_state * train) { struct ggml_context * f_ggml_ctx; struct gguf_init_params params; params.no_alloc = false; @@ -1082,19 +1010,19 @@ static bool load_checkpoint_lora_file(const char * filename, struct my_llama_mod return false; } - load_checkpoint_lora_gguf(fctx, f_ggml_ctx, model, lora, opt); + load_checkpoint_lora_gguf(fctx, f_ggml_ctx, model, lora, train); gguf_free(fctx); return true; } -static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct ggml_opt_context * opt, const char * pattern_it, int iteration, const char * latest) { +static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train, const char * pattern_it, int iteration, const char * latest) { std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); std::string fn = replace_str(filename, pattern_it, sit.c_str()); printf("%s: saving to %s\n", __func__, fn.c_str()); struct gguf_context * fctx = gguf_init_empty(); - save_checkpoint_lora_gguf(fctx, model, lora, opt); + save_checkpoint_lora_gguf(fctx, model, lora, train); // write file const bool only_meta = false; @@ -1896,30 +1824,31 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par } struct opt_callback_data { - struct train_params * params; - struct ggml_opt_context * opt; - struct my_llama_model * model; - struct my_llama_lora * lora; - struct llama_context * lctx; - int last_save_iter; - llama_token * tokens_data; - size_t tokens_size; - size_t * samples_begin; - size_t * samples_size; - size_t * shuffled_samples_begin; - size_t * shuffled_samples_size; - size_t samples_count; - struct ggml_tensor * tokens_input; - struct ggml_tensor * target_probs; - int first_iter; - int64_t last_time; - double millis_per_iter; + struct train_params * params; + struct train_state * train; + struct my_llama_model * model; + struct my_llama_lora * lora; + struct llama_context * lctx; + int last_save_iter; + llama_token * tokens_data; + size_t tokens_size; + size_t * samples_begin; + size_t * samples_size; + size_t * shuffled_samples_begin; + size_t * shuffled_samples_size; + size_t samples_count; + struct ggml_tensor * tokens_input; + struct ggml_tensor * target_probs; + int first_iter; + int64_t last_time; + double millis_per_iter; }; static void opt_callback(void * vdata, int accum_step, float * sched) { - struct opt_callback_data * data = (struct opt_callback_data *) vdata; - struct train_params * params = data->params; - struct ggml_opt_context * opt = data->opt; + struct opt_callback_data * data = (struct opt_callback_data *) vdata; + struct train_params * params = data->params; + struct train_state * train = data->train; + struct ggml_opt_context * opt = train->opt; int n_batch = params->n_batch; int n_ctx = params->n_ctx; @@ -1948,13 +1877,13 @@ static void opt_callback(void * vdata, int accum_step, float * sched) { const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every); if (save_now) { int new_iters = opt->iter - data->last_save_iter; - data->lora->train_its += new_iters; - data->lora->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; - data->lora->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx; + train->train_its += new_iters; + train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; + 
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx; if (strlen(params->fn_checkpoint_out) > 0) { - save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, opt->iter, params->fn_latest); - save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, -1, params->fn_latest); + save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, train, params->pattern_fn_it, opt->iter, params->fn_latest); + save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, train, params->pattern_fn_it, -1, params->fn_latest); } if (strlen(params->fn_lora_out) > 0) { save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest); @@ -1980,7 +1909,7 @@ static void opt_callback(void * vdata, int accum_step, float * sched) { if (impr_plot > 0) impr_plot = 0; if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0; printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f", - __func__, opt->iter, std::min(1+data->lora->shuffle_next_sample, data->lora->shuffle_sample_count), data->lora->shuffle_sample_count, + __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count, *sched, opt->loss_after); @@ -2006,7 +1935,7 @@ static void opt_callback(void * vdata, int accum_step, float * sched) { data->lctx, data->tokens_input, data->target_probs, - data->lora->shuffle_next_sample, + train->shuffle_next_sample, data->shuffled_samples_begin, data->shuffled_samples_size, data->samples_count, @@ -2016,21 +1945,21 @@ static void opt_callback(void * vdata, int accum_step, float * sched) { params->separate_with_bos, params->fill_with_next_samples); - data->lora->shuffle_next_sample += used_samples; + train->shuffle_next_sample += used_samples; - if (data->lora->shuffle_next_sample >= data->lora->shuffle_sample_count) { - ++data->lora->train_epochs; - printf("%s: reshuffle samples. 
completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs); // note: we may have used some samples from the current shuffling more than once - data->lora->shuffle_rng_state_current = data->lora->shuffle_rng_state_next; - data->lora->shuffle_rng_state_next = shuffle_samples( - data->lora->shuffle_rng_state_current, + train->shuffle_rng_state_current = train->shuffle_rng_state_next; + train->shuffle_rng_state_next = shuffle_samples( + train->shuffle_rng_state_current, data->shuffled_samples_begin, data->shuffled_samples_size, data->samples_begin, data->samples_size, data->samples_count); - data->lora->shuffle_next_sample = 0; + train->shuffle_next_sample = 0; } } @@ -2091,10 +2020,9 @@ int main(int argc, char ** argv) { init_model(lmodel, &model, params.n_ctx); struct my_llama_lora lora; - struct ggml_opt_context* opt = (struct ggml_opt_context*)alloca(sizeof(struct ggml_opt_context)); - memset(opt, 0, sizeof(struct ggml_opt_context)); - opt->ctx = NULL; + struct train_state * train = init_train_state(params.seed); + struct ggml_opt_context * opt = train->opt; load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams); @@ -2157,7 +2085,7 @@ int main(int argc, char ** argv) { ggml_allocr * alloc = NULL; printf("%s: init model\n", __func__); - bool existed = load_checkpoint_lora_file(params.fn_checkpoint_in, &model, &lora, opt); + bool existed = load_checkpoint_lora_file(params.fn_checkpoint_in, &model, &lora, train); if (existed) { // overwrite last n_ctx with user provided n_ctx @@ -2203,13 +2131,13 @@ int main(int argc, char ** argv) { print_params(&model.hparams); print_lora_params(&lora.hparams); - printf("%s: total train_iterations %llu\n", __func__, (long long unsigned) lora.train_its); - printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) lora.train_samples); - printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) lora.train_tokens); - printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) lora.train_epochs); + printf("%s: total train_iterations %llu\n", __func__, (long long unsigned) train->train_its); + printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples); + printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens); + printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs); printf("%s: max_lora_size = %zu bytes (%.1f MB)\n", __func__, lora.data.size(), (float) lora.data.size() / (1024.0f*1024.0f)); printf("%s: max_opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f)); - opt->iter = lora.train_its; + opt->iter = train->train_its; if (params.only_write_lora) { if (strlen(params.fn_lora_out) > 0) { @@ -2368,25 +2296,25 @@ int main(int argc, char ** argv) { printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens); size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size()); - const bool changed_train_data = (shuffle_samples_hash != lora.shuffle_samples_hash) || (lora.shuffle_sample_count != train_samples_size.size()); + const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size()); if (changed_train_data) { printf("%s: train data seems to have changed. 
restarting shuffled epoch.\n", __func__); } if (params.force_reshuffle) { printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__); } - if ((lora.shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) { - lora.shuffle_rng_state_current = mt19937_seed_to_state(params.seed); - lora.shuffle_sample_count = train_samples_size.size(); - lora.shuffle_next_sample = 0; - lora.shuffle_samples_hash = shuffle_samples_hash; + if ((train->shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) { + train->shuffle_rng_state_current = mt19937_seed_to_state(params.seed); + train->shuffle_sample_count = train_samples_size.size(); + train->shuffle_next_sample = 0; + train->shuffle_samples_hash = shuffle_samples_hash; } std::vector train_shuffled_samples_begin; std::vector train_shuffled_samples_size; train_shuffled_samples_begin.resize(train_samples_begin.size()); train_shuffled_samples_size.resize(train_samples_size.size()); - lora.shuffle_rng_state_next = shuffle_samples( - lora.shuffle_rng_state_current, + train->shuffle_rng_state_next = shuffle_samples( + train->shuffle_rng_state_current, train_shuffled_samples_begin.data(), train_shuffled_samples_size.data(), train_samples_begin.data(), @@ -2397,7 +2325,7 @@ int main(int argc, char ** argv) { struct opt_callback_data opt_cb_data; opt_cb_data.params = ¶ms; - opt_cb_data.opt = opt; + opt_cb_data.train = train; opt_cb_data.model = &model; opt_cb_data.lora = &lora; opt_cb_data.lctx = lctx; @@ -2442,13 +2370,13 @@ int main(int argc, char ** argv) { int new_iters = opt->iter - opt_cb_data.last_save_iter; if (new_iters > 0) { - lora.train_its += new_iters; - lora.train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; - lora.train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens; + train->train_its += new_iters; + train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; + train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens; if (strlen(params.fn_checkpoint_out) > 0) { - save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, opt, params.pattern_fn_it, opt->iter, params.fn_latest); - save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, opt, params.pattern_fn_it, -1, params.fn_latest); + save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, train, params.pattern_fn_it, opt->iter, params.fn_latest); + save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, train, params.pattern_fn_it, -1, params.fn_latest); } if (strlen(params.fn_lora_out) > 0) { save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, opt->iter, params.fn_latest); @@ -2458,6 +2386,7 @@ int main(int argc, char ** argv) { } ggml_free(opt->ctx); + free_train_state(train); ggml_free(lora.ctx); llama_free(lctx); llama_free_model(lmodel); diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 63edcf9ef..bead80843 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -65,34 +65,8 @@ struct my_llama_model { struct ggml_tensor * output; std::vector layers; - - uint64_t train_its = 0; - uint64_t train_samples = 0; - uint64_t train_tokens = 0; - uint64_t train_epochs = 0; - - size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes) - std::string 
shuffle_rng_state_current; - std::string shuffle_rng_state_next; - size_t shuffle_sample_count; - size_t shuffle_next_sample; }; -// gguf constants -static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model"; -static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora"; -static const char * LLM_KV_TRAINING_TYPE = "training.type"; -static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version"; -static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"; -static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count"; -static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count"; -static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count"; -static const char * LLM_KV_TRAINING_SAMPLES_HASH = "training.samples_hash"; -static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash"; -static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state"; -static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count"; -static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE = "training.shuffle.next_sample"; - // gguf constants (sync with gguf.py) static const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture"; @@ -152,11 +126,6 @@ static void init_model(struct my_llama_model * model) { struct ggml_context * ctx = model->ctx; - model->train_its = 0; - model->train_samples = 0; - model->train_tokens = 0; - model->train_epochs = 0; - std::vector<char> tn_buf; tn_buf.resize(GGML_MAX_NAME); auto tn = [&tn_buf](const char * key) -> const char * { @@ -685,62 +654,19 @@ static void save_llama_model_file(const char * filename, const char * fn_vocab_m gguf_free(fctx); } -static void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct ggml_opt_context * opt) { +static void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct train_state * train) { load_llama_model_gguf(fctx, f_ggml_ctx, model); - - if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) >= 0) { - uint32_t file_version = 0xFFFFFFFFu; - GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION); - GGML_ASSERT(file_version <= 1); - - std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL; - GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE); - GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL); - - if (file_version == 0) { - - GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT); - GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT); - GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT); - - } else if (file_version == 1) { - - GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT); - GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT); - GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT); - GGUF_GET_KEY(fctx, model->train_epochs, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT); - - GGUF_GET_KEY(fctx, model->shuffle_samples_hash, gguf_get_val_u64, GGUF_TYPE_UINT64, false, 
LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH); - GGUF_GET_KEY(fctx, model->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE); - GGUF_GET_KEY(fctx, model->shuffle_sample_count, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT); - GGUF_GET_KEY(fctx, model->shuffle_next_sample, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE); - } - - load_opt_context_gguf(fctx, f_ggml_ctx, opt); - } else { + if (!load_train_state_gguf(fctx, f_ggml_ctx, train)) { printf("%s: loaded llama model as checkpoint\n", __func__); } } -static void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) { +static void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) { save_llama_model_gguf(fctx, fn_vocab_model, model); - - gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1); - gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, model->train_its); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, model->train_samples); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, model->train_tokens); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT, model->train_epochs); - - gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) model->shuffle_samples_hash); - gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE, model->shuffle_rng_state_current.c_str()); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) model->shuffle_sample_count); - gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE, (uint64_t) model->shuffle_next_sample); - - save_opt_context_gguf(fctx, opt); + save_train_state_gguf(fctx, train); } -static bool load_checkpoint_file(const char * filename, struct my_llama_model * model, struct ggml_opt_context * opt) { +static bool load_checkpoint_file(const char * filename, struct my_llama_model * model, struct train_state * train) { struct ggml_context * f_ggml_ctx; struct gguf_init_params params; params.no_alloc = false; @@ -750,18 +676,18 @@ static bool load_checkpoint_file(const char * filename, struct my_llama_model * return false; } - load_checkpoint_gguf(fctx, f_ggml_ctx, model, opt); + load_checkpoint_gguf(fctx, f_ggml_ctx, model, train); return true; } -static void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt, const char * pattern_it, int iteration, const char * latest) { +static void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train, const char * pattern_it, int iteration, const char * latest) { std::string sit = (iteration >= 0) ? 
std::to_string(iteration) : std::string(latest); std::string fn = replace_str(filename, pattern_it, sit.c_str()); printf("%s: saving to %s\n", __func__, fn.c_str()); struct gguf_context * fctx = gguf_init_empty(); - save_checkpoint_gguf(fctx, fn_vocab_model, model, opt); + save_checkpoint_gguf(fctx, fn_vocab_model, model, train); // write file const bool only_meta = false; @@ -1295,30 +1221,31 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par } struct opt_callback_data { - struct train_params * params; - struct ggml_opt_context * opt; - struct my_llama_model * model; - struct llama_context * lctx; - int last_save_iter; - llama_token * tokens_data; - size_t tokens_size; - size_t * samples_begin; - size_t * samples_size; - size_t * shuffled_samples_begin; - size_t * shuffled_samples_size; - size_t samples_count; - struct ggml_tensor * tokens_input; - struct ggml_tensor * target_logits; - struct ggml_tensor * target_probs; - int first_iter; - int64_t last_time; - double millis_per_iter; + struct train_params * params; + struct train_state * train; + struct my_llama_model * model; + struct llama_context * lctx; + int last_save_iter; + llama_token * tokens_data; + size_t tokens_size; + size_t * samples_begin; + size_t * samples_size; + size_t * shuffled_samples_begin; + size_t * shuffled_samples_size; + size_t samples_count; + struct ggml_tensor * tokens_input; + struct ggml_tensor * target_logits; + struct ggml_tensor * target_probs; + int first_iter; + int64_t last_time; + double millis_per_iter; }; static void opt_callback(void * vdata, int accum_step, float * sched) { - struct opt_callback_data * data = (struct opt_callback_data *) vdata; - struct train_params * params = data->params; - struct ggml_opt_context * opt = data->opt; + struct opt_callback_data * data = (struct opt_callback_data *) vdata; + struct train_params * params = data->params; + struct train_state * train = data->train; + struct ggml_opt_context * opt = train->opt; int n_batch = params->n_batch; int n_ctx = params->n_ctx; @@ -1347,13 +1274,13 @@ static void opt_callback(void * vdata, int accum_step, float * sched) { const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every); if (save_now) { int new_iters = opt->iter - data->last_save_iter; - data->model->train_its += new_iters; - data->model->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; - data->model->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx; + train->train_its += new_iters; + train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; + train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx; if (strlen(params->fn_checkpoint_out) > 0) { - save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, opt->iter, params->fn_latest); - save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, -1, params->fn_latest); + save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, train, params->pattern_fn_it, opt->iter, params->fn_latest); + save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, train, params->pattern_fn_it, -1, params->fn_latest); } if (strlen(params->fn_model_out) > 0) { @@ -1380,7 +1307,7 @@ static void opt_callback(void * vdata, int accum_step, float * sched) { if (impr_plot > 0) impr_plot = 
0; if (std::isnan(opt->loss_before) || std::isnan(opt->loss_after)) impr_plot = 0; printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f", - __func__, opt->iter, std::min(1+data->model->shuffle_next_sample, data->model->shuffle_sample_count), data->model->shuffle_sample_count, + __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count, *sched, opt->loss_after); @@ -1406,7 +1333,7 @@ static void opt_callback(void * vdata, int accum_step, float * sched) { data->lctx, data->tokens_input, data->target_probs, - data->model->shuffle_next_sample, + train->shuffle_next_sample, data->shuffled_samples_begin, data->shuffled_samples_size, data->samples_count, @@ -1416,21 +1343,21 @@ params->separate_with_bos, params->fill_with_next_samples); - data->model->shuffle_next_sample += used_samples; + train->shuffle_next_sample += used_samples; - if (data->model->shuffle_next_sample >= data->model->shuffle_sample_count) { - ++data->model->train_epochs; - printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) data->model->train_epochs); + if (train->shuffle_next_sample >= train->shuffle_sample_count) { + ++train->train_epochs; + printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs); // note: we may have used some samples from the current shuffling more than once - data->model->shuffle_rng_state_current = data->model->shuffle_rng_state_next; - data->model->shuffle_rng_state_next = shuffle_samples( - data->model->shuffle_rng_state_current, + train->shuffle_rng_state_current = train->shuffle_rng_state_next; + train->shuffle_rng_state_next = shuffle_samples( + train->shuffle_rng_state_current, data->shuffled_samples_begin, data->shuffled_samples_size, data->samples_begin, data->samples_size, data->samples_count); - data->model->shuffle_next_sample = 0; + train->shuffle_next_sample = 0; } } @@ -1480,8 +1407,8 @@ int main(int argc, char ** argv) { int n_vocab = model.hparams.n_vocab; int n_batch = params.n_batch; - struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context)); - memset(opt, 0, sizeof(struct ggml_opt_context)); + struct train_state * train = init_train_state(params.seed); + struct ggml_opt_context * opt = train->opt; struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM); opt_params_adam.print_forward_graph = false; @@ -1505,7 +1432,7 @@ opt->params = opt_params_adam; printf("%s: init model\n", __func__); - bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, opt); + bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, train); if (!existed) { init_model(&model); } @@ -1513,7 +1440,7 @@ opt->params = opt_params_adam; - opt->iter = model.train_its; + opt->iter = train->train_its; printf("%s: opt iter %d\n", __func__, opt->iter); bool from_scratch = !existed; @@ -1555,25 +1482,25 @@ printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size()); size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size()); - const bool 
changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size()); if (changed_train_data) { printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__); } if (params.force_reshuffle) { printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__); } - if ((model.shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) { - model.shuffle_rng_state_current = mt19937_seed_to_state(params.seed); - model.shuffle_sample_count = train_samples_size.size(); - model.shuffle_next_sample = 0; - model.shuffle_samples_hash = shuffle_samples_hash; + if ((train->shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) { + train->shuffle_rng_state_current = mt19937_seed_to_state(params.seed); + train->shuffle_sample_count = train_samples_size.size(); + train->shuffle_next_sample = 0; + train->shuffle_samples_hash = shuffle_samples_hash; } std::vector<size_t> train_shuffled_samples_begin; std::vector<size_t> train_shuffled_samples_size; train_shuffled_samples_begin.resize(train_samples_begin.size()); train_shuffled_samples_size.resize(train_samples_size.size()); - model.shuffle_rng_state_next = shuffle_samples( - model.shuffle_rng_state_current, + train->shuffle_rng_state_next = shuffle_samples( + train->shuffle_rng_state_current, train_shuffled_samples_begin.data(), train_shuffled_samples_size.data(), train_samples_begin.data(), @@ -1583,7 +1510,7 @@ int main(int argc, char ** argv) { struct opt_callback_data opt_cb_data; opt_cb_data.params = &params; - opt_cb_data.opt = opt; + opt_cb_data.train = train; opt_cb_data.model = &model; opt_cb_data.lctx = lctx; opt_cb_data.last_save_iter = opt->iter; @@ -1594,9 +1521,9 @@ opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data(); opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data(); opt_cb_data.samples_count = train_samples_size.size(); - opt_cb_data.tokens_input = NULL; - opt_cb_data.target_logits = NULL; - opt_cb_data.target_probs = NULL; + opt_cb_data.tokens_input = NULL; + opt_cb_data.target_logits = NULL; + opt_cb_data.target_probs = NULL; opt_cb_data.first_iter = opt->iter; opt_cb_data.last_time = ggml_time_ms(); opt_cb_data.millis_per_iter = 0.0; @@ -1672,9 +1599,9 @@ size_t used_mem_after_opt = ggml_used_mem(ctx0); int n_iter = params.adam_n_iter; - model.train_its = opt->iter; - model.train_samples += n_batch * n_iter; - model.train_tokens += n_batch * n_tokens * n_iter; + train->train_its = opt->iter; + train->train_samples += n_batch * n_iter; + train->train_tokens += n_batch * n_tokens * n_iter; if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) { printf("Example %d, opt iter %d\n", ex, opt->iter); @@ -1693,13 +1620,13 @@ printf("%s: total training time=%f seconds\n", __func__, dd); int new_iters = opt->iter - opt_cb_data.last_save_iter; - model.train_its += new_iters; - model.train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; - model.train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens; + train->train_its += new_iters; + train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; + train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens; if (params.n_examples > 0) { - save_checkpoint_file(params.fn_checkpoint_out, 
params.fn_vocab_model, &model, opt, params.pattern_fn_it, opt->iter, params.fn_latest); - save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt, params.pattern_fn_it, -1, params.fn_latest); + save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, train, params.pattern_fn_it, opt->iter, params.fn_latest); + save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, train, params.pattern_fn_it, -1, params.fn_latest); } if (strlen(params.fn_model_out) > 0) { @@ -1715,6 +1642,7 @@ int main(int argc, char ** argv) { delete[] compute_addr; delete[] compute_buf_0; + free_train_state(train); ggml_free(model.ctx); llama_free(lctx); llama_free_model(lmodel);
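
Note: the following is a minimal sketch of the checkpoint lifecycle that both examples now share after this patch, not code from the patch itself. `params`, `fctx`, `f_ggml_ctx`, and the output filename are stand-ins for what the caller already has; `gguf_init_empty`, `gguf_write_to_file`, and `gguf_free` are the existing ggml GGUF API.

struct train_state * train = init_train_state(params.seed);
struct ggml_opt_context * opt = train->opt;

// load_train_state_gguf returns false when the file has no
// training.file_version key, e.g. when a plain model file is passed as
// checkpoint; callers treat that as "start a fresh training run"
if (load_train_state_gguf(fctx, f_ggml_ctx, train)) {
    opt->iter = train->train_its;  // resume the optimizer where it stopped
}

// ... training loop: opt_callback advances train->train_its,
//     train->train_samples, train->train_tokens and the shuffle state ...

struct gguf_context * fctx_out = gguf_init_empty();
save_train_state_gguf(fctx_out, train);        // writes training.file_version = 1
gguf_write_to_file(fctx_out, "checkpoint.gguf", false);
gguf_free(fctx_out);
free_train_state(train);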
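The shuffle_rng_state_* members are plain strings (the mt19937_state typedef moved into train.h above) because std::mt19937 round-trips through the standard stream operators, which is what lets the complete RNG state be stored in the training.shuffle.rng_state GGUF string key. One possible implementation of the helpers declared in train.h, sketched here on the assumption that it matches the existing definitions in train.cpp:

#include <random>
#include <sstream>

mt19937_state mt19937_get_state(const std::mt19937& rng) {
    std::stringstream s_rng_state;  // text dump of the full engine state
    s_rng_state << rng;
    return s_rng_state.str();
}

void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state) {
    std::stringstream s_rng_state;
    s_rng_state.str(rng_state);
    s_rng_state >> rng;
}

mt19937_state mt19937_seed_to_state(unsigned seed) {
    std::mt19937 rng(seed);
    return mt19937_get_state(rng);
}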
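load_train_state_gguf leans on the req argument of GGUF_GET_KEY: once a file version is present, the iteration/sample/token/epoch counters are required (req = true), while the shuffle keys are optional (req = false) and simply leave the zero-initialized defaults from init_train_state in place when absent. The macro body is cut off in the hunk context above; a simplified expansion, with the error handling reduced to exit() purely for illustration:

#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
{ \
    const std::string skey(key); \
    const int kid = gguf_find_key(ctx, skey.c_str()); \
    if (kid >= 0) { \
        if (gguf_get_kv_type(ctx, kid) != (type)) { \
            fprintf(stderr, "key %s has wrong type\n", skey.c_str()); \
            exit(1); \
        } \
        (dst) = func(ctx, kid); \
    } else if (req) { \
        fprintf(stderr, "key not found: %s\n", skey.c_str()); \
        exit(1); \
    } \
}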