move train state into struct train_state

This commit is contained in:
xaedes 2023-09-16 17:08:18 +02:00
parent 9f4b1bf88d
commit a8c8907c62
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
4 changed files with 250 additions and 282 deletions

View file

@ -17,6 +17,28 @@ struct random_uniform_distribution {
std::uniform_real_distribution<float> rd; std::uniform_real_distribution<float> rd;
}; };
struct train_state * init_train_state(int seed) {
struct train_state * state = (struct train_state *) malloc(sizeof(struct train_state));
memset(state, 0, sizeof(struct train_state));
state->shuffle_rng_state_current = "";
state->shuffle_rng_state_next = "";
state->opt = (struct ggml_opt_context *) malloc(sizeof(struct ggml_opt_context));
memset(state->opt, 0, sizeof(struct ggml_opt_context));
state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
return state;
}
void free_train_state(struct train_state * state) {
free(state->opt);
free(state);
}
struct ggml_opt_context * get_train_state_opt(struct train_state * state) {
return state->opt;
}
struct random_normal_distribution * init_random_normal_distribution(int seed, float mean, float std, float min, float max) { struct random_normal_distribution * init_random_normal_distribution(int seed, float mean, float std, float min, float max) {
struct random_normal_distribution * rnd = (struct random_normal_distribution *) malloc(sizeof(struct random_normal_distribution)); struct random_normal_distribution * rnd = (struct random_normal_distribution *) malloc(sizeof(struct random_normal_distribution));
rnd->gen = std::mt19937(seed); rnd->gen = std::mt19937(seed);
@ -472,6 +494,20 @@ static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s"; static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y"; static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y";
static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
static const char * LLM_KV_TRAINING_TYPE = "training.type";
static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";
static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count";
static const char * LLM_KV_TRAINING_SAMPLES_HASH = "training.samples_hash";
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state";
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE = "training.shuffle.next_sample";
#define GGUF_GET_KEY(ctx, dst, func, type, req, key) \ #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
{ \ { \
const std::string skey(key); \ const std::string skey(key); \
@ -613,6 +649,59 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context *
} }
} }
bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train) {
if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) >= 0) {
return false;
}
uint32_t file_version;
GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
GGML_ASSERT(file_version <= 1);
std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
if (file_version == 0) {
GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
} else if (file_version == 1) {
GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT);
GGUF_GET_KEY(fctx, train->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT);
GGUF_GET_KEY(fctx, train->train_tokens, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT);
GGUF_GET_KEY(fctx, train->train_epochs, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT);
GGUF_GET_KEY(fctx, train->shuffle_samples_hash, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH);
GGUF_GET_KEY(fctx, train->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE);
GGUF_GET_KEY(fctx, train->shuffle_sample_count, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT);
GGUF_GET_KEY(fctx, train->shuffle_next_sample, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE);
}
load_opt_context_gguf(fctx, f_ggml_ctx, train->opt);
return true;
}
void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) {
gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1);
gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, train->train_samples);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, train->train_tokens);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT, train->train_epochs);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) train->shuffle_samples_hash);
gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE, train->shuffle_rng_state_current.c_str());
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) train->shuffle_sample_count);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE, (uint64_t) train->shuffle_next_sample);
save_opt_context_gguf(fctx, train->opt);
}
struct llama_file { struct llama_file {
// use FILE * so we don't have to re-open the file to mmap // use FILE * so we don't have to re-open the file to mmap
FILE * fp; FILE * fp;

View file

@ -9,6 +9,26 @@
#include "ggml.h" #include "ggml.h"
#include "llama.h" #include "llama.h"
typedef std::string mt19937_state;
struct train_state {
struct ggml_opt_context * opt;
uint64_t train_its;
uint64_t train_samples;
uint64_t train_tokens;
uint64_t train_epochs;
size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes)
mt19937_state shuffle_rng_state_current;
mt19937_state shuffle_rng_state_next;
size_t shuffle_sample_count;
size_t shuffle_next_sample;
};
struct train_state * init_train_state(int seed);
void free_train_state(struct train_state * state);
struct random_normal_distribution; struct random_normal_distribution;
struct random_uniform_distribution; struct random_uniform_distribution;
@ -58,7 +78,6 @@ int64_t get_example_targets_batch(
bool separate_with_bos, bool separate_with_bos,
bool fill_with_next_samples); bool fill_with_next_samples);
typedef std::string mt19937_state;
void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state); void mt19937_set_state(std::mt19937& rng, const mt19937_state& rng_state);
mt19937_state mt19937_get_state(const std::mt19937& rng); mt19937_state mt19937_get_state(const std::mt19937& rng);
@ -111,3 +130,6 @@ void copy_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, co
void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt); void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt);
void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt); void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * opt);
bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train);
void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train);

View file

@ -148,34 +148,9 @@ struct my_llama_lora {
struct ggml_tensor * output_b; struct ggml_tensor * output_b;
std::vector<my_llama_lora_layer> layers; std::vector<my_llama_lora_layer> layers;
uint64_t train_its = 0;
uint64_t train_samples = 0;
uint64_t train_tokens = 0;
uint64_t train_epochs = 0;
size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes)
std::string shuffle_rng_state_current;
std::string shuffle_rng_state_next;
size_t shuffle_sample_count;
size_t shuffle_next_sample;
}; };
// gguf constants // gguf constants
static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
static const char * LLM_KV_TRAINING_TYPE = "training.type";
static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";
static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count";
static const char * LLM_KV_TRAINING_SAMPLES_HASH = "training.samples_hash";
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state";
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE = "training.shuffle.next_sample";
static const char * LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD = "training.lora.rank.token_embd"; static const char * LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD = "training.lora.rank.token_embd";
static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM = "training.lora.rank.output_norm"; static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM = "training.lora.rank.output_norm";
static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT = "training.lora.rank.output"; static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT = "training.lora.rank.output";
@ -336,10 +311,6 @@ static void init_lora(const struct my_llama_model * model, struct my_llama_lora
const uint32_t n_vocab = model->hparams.n_vocab; const uint32_t n_vocab = model->hparams.n_vocab;
const uint32_t n_ff = model->hparams.n_ff; const uint32_t n_ff = model->hparams.n_ff;
lora->train_its = 0;
lora->train_samples = 0;
lora->train_tokens = 0;
std::vector<char> tn_buf; std::vector<char> tn_buf;
tn_buf.resize(GGML_MAX_NAME); tn_buf.resize(GGML_MAX_NAME);
auto tn = [&tn_buf](const char * key, const char * suffix) -> const char * { auto tn = [&tn_buf](const char * key, const char * suffix) -> const char * {
@ -869,8 +840,6 @@ static void load_default_lora_params_from_base_model(const char * fn_base_model,
gguf_free(fctx); gguf_free(fctx);
} }
static void load_llama_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora) { static void load_llama_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora) {
// NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
@ -1021,58 +990,17 @@ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_mod
} }
} }
static void load_checkpoint_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora, struct ggml_opt_context * opt) { static void load_checkpoint_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora); load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora);
load_train_state_gguf(fctx, f_ggml_ctx, train);
uint32_t file_version;
GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
GGML_ASSERT(file_version <= 1);
std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
if (file_version == 0) {
GGUF_GET_KEY(fctx, lora->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
GGUF_GET_KEY(fctx, lora->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
GGUF_GET_KEY(fctx, lora->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
} else if (file_version == 1) {
GGUF_GET_KEY(fctx, lora->train_its, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT);
GGUF_GET_KEY(fctx, lora->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT);
GGUF_GET_KEY(fctx, lora->train_tokens, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT);
GGUF_GET_KEY(fctx, lora->train_epochs, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT);
GGUF_GET_KEY(fctx, lora->shuffle_samples_hash, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH);
GGUF_GET_KEY(fctx, lora->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE);
GGUF_GET_KEY(fctx, lora->shuffle_sample_count, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT);
GGUF_GET_KEY(fctx, lora->shuffle_next_sample, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE);
} }
load_opt_context_gguf(fctx, f_ggml_ctx, opt); static void save_checkpoint_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
}
static void save_checkpoint_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora, struct ggml_opt_context * opt) {
save_llama_lora_gguf(fctx, model, lora); save_llama_lora_gguf(fctx, model, lora);
save_train_state_gguf(fctx, train);
gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1);
gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, lora->train_its);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, lora->train_samples);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, lora->train_tokens);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT, lora->train_epochs);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) lora->shuffle_samples_hash);
gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE, lora->shuffle_rng_state_current.c_str());
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) lora->shuffle_sample_count);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE, (uint64_t) lora->shuffle_next_sample);
save_opt_context_gguf(fctx, opt);
} }
static bool load_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct ggml_opt_context * opt) { static bool load_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
struct ggml_context * f_ggml_ctx; struct ggml_context * f_ggml_ctx;
struct gguf_init_params params; struct gguf_init_params params;
params.no_alloc = false; params.no_alloc = false;
@ -1082,19 +1010,19 @@ static bool load_checkpoint_lora_file(const char * filename, struct my_llama_mod
return false; return false;
} }
load_checkpoint_lora_gguf(fctx, f_ggml_ctx, model, lora, opt); load_checkpoint_lora_gguf(fctx, f_ggml_ctx, model, lora, train);
gguf_free(fctx); gguf_free(fctx);
return true; return true;
} }
static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct ggml_opt_context * opt, const char * pattern_it, int iteration, const char * latest) { static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train, const char * pattern_it, int iteration, const char * latest) {
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
std::string fn = replace_str(filename, pattern_it, sit.c_str()); std::string fn = replace_str(filename, pattern_it, sit.c_str());
printf("%s: saving to %s\n", __func__, fn.c_str()); printf("%s: saving to %s\n", __func__, fn.c_str());
struct gguf_context * fctx = gguf_init_empty(); struct gguf_context * fctx = gguf_init_empty();
save_checkpoint_lora_gguf(fctx, model, lora, opt); save_checkpoint_lora_gguf(fctx, model, lora, train);
// write file // write file
const bool only_meta = false; const bool only_meta = false;
@ -1897,7 +1825,7 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
struct opt_callback_data { struct opt_callback_data {
struct train_params * params; struct train_params * params;
struct ggml_opt_context * opt; struct train_state * train;
struct my_llama_model * model; struct my_llama_model * model;
struct my_llama_lora * lora; struct my_llama_lora * lora;
struct llama_context * lctx; struct llama_context * lctx;
@ -1919,7 +1847,8 @@ struct opt_callback_data {
static void opt_callback(void * vdata, int accum_step, float * sched) { static void opt_callback(void * vdata, int accum_step, float * sched) {
struct opt_callback_data * data = (struct opt_callback_data *) vdata; struct opt_callback_data * data = (struct opt_callback_data *) vdata;
struct train_params * params = data->params; struct train_params * params = data->params;
struct ggml_opt_context * opt = data->opt; struct train_state * train = data->train;
struct ggml_opt_context * opt = train->opt;
int n_batch = params->n_batch; int n_batch = params->n_batch;
int n_ctx = params->n_ctx; int n_ctx = params->n_ctx;
@ -1948,13 +1877,13 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every); const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
if (save_now) { if (save_now) {
int new_iters = opt->iter - data->last_save_iter; int new_iters = opt->iter - data->last_save_iter;
data->lora->train_its += new_iters; train->train_its += new_iters;
data->lora->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
data->lora->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx; train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
if (strlen(params->fn_checkpoint_out) > 0) { if (strlen(params->fn_checkpoint_out) > 0) {
save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, opt->iter, params->fn_latest); save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, train, params->pattern_fn_it, opt->iter, params->fn_latest);
save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, -1, params->fn_latest); save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, train, params->pattern_fn_it, -1, params->fn_latest);
} }
if (strlen(params->fn_lora_out) > 0) { if (strlen(params->fn_lora_out) > 0) {
save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest); save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest);
@ -1980,7 +1909,7 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
if (impr_plot > 0) impr_plot = 0; if (impr_plot > 0) impr_plot = 0;
if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0; if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0;
printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f", printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
__func__, opt->iter, std::min(1+data->lora->shuffle_next_sample, data->lora->shuffle_sample_count), data->lora->shuffle_sample_count, __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
*sched, opt->loss_after); *sched, opt->loss_after);
@ -2006,7 +1935,7 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
data->lctx, data->lctx,
data->tokens_input, data->tokens_input,
data->target_probs, data->target_probs,
data->lora->shuffle_next_sample, train->shuffle_next_sample,
data->shuffled_samples_begin, data->shuffled_samples_begin,
data->shuffled_samples_size, data->shuffled_samples_size,
data->samples_count, data->samples_count,
@ -2016,21 +1945,21 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
params->separate_with_bos, params->separate_with_bos,
params->fill_with_next_samples); params->fill_with_next_samples);
data->lora->shuffle_next_sample += used_samples; train->shuffle_next_sample += used_samples;
if (data->lora->shuffle_next_sample >= data->lora->shuffle_sample_count) { if (train->shuffle_next_sample >= train->shuffle_sample_count) {
++data->lora->train_epochs; ++train->train_epochs;
printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) data->lora->train_epochs); printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs);
// note: we may have used some samples from the current shuffling more than once // note: we may have used some samples from the current shuffling more than once
data->lora->shuffle_rng_state_current = data->lora->shuffle_rng_state_next; train->shuffle_rng_state_current = train->shuffle_rng_state_next;
data->lora->shuffle_rng_state_next = shuffle_samples( train->shuffle_rng_state_next = shuffle_samples(
data->lora->shuffle_rng_state_current, train->shuffle_rng_state_current,
data->shuffled_samples_begin, data->shuffled_samples_begin,
data->shuffled_samples_size, data->shuffled_samples_size,
data->samples_begin, data->samples_begin,
data->samples_size, data->samples_size,
data->samples_count); data->samples_count);
data->lora->shuffle_next_sample = 0; train->shuffle_next_sample = 0;
} }
} }
@ -2091,10 +2020,9 @@ int main(int argc, char ** argv) {
init_model(lmodel, &model, params.n_ctx); init_model(lmodel, &model, params.n_ctx);
struct my_llama_lora lora; struct my_llama_lora lora;
struct ggml_opt_context* opt = (struct ggml_opt_context*)alloca(sizeof(struct ggml_opt_context));
memset(opt, 0, sizeof(struct ggml_opt_context));
opt->ctx = NULL; struct train_state * train = init_train_state(params.seed);
struct ggml_opt_context * opt = train->opt;
load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams); load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams);
@ -2157,7 +2085,7 @@ int main(int argc, char ** argv) {
ggml_allocr * alloc = NULL; ggml_allocr * alloc = NULL;
printf("%s: init model\n", __func__); printf("%s: init model\n", __func__);
bool existed = load_checkpoint_lora_file(params.fn_checkpoint_in, &model, &lora, opt); bool existed = load_checkpoint_lora_file(params.fn_checkpoint_in, &model, &lora, train);
if (existed) { if (existed) {
// overwrite last n_ctx with user provided n_ctx // overwrite last n_ctx with user provided n_ctx
@ -2203,13 +2131,13 @@ int main(int argc, char ** argv) {
print_params(&model.hparams); print_params(&model.hparams);
print_lora_params(&lora.hparams); print_lora_params(&lora.hparams);
printf("%s: total train_iterations %llu\n", __func__, (long long unsigned) lora.train_its); printf("%s: total train_iterations %llu\n", __func__, (long long unsigned) train->train_its);
printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) lora.train_samples); printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples);
printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) lora.train_tokens); printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens);
printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) lora.train_epochs); printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
printf("%s: max_lora_size = %zu bytes (%.1f MB)\n", __func__, lora.data.size(), (float) lora.data.size() / (1024.0f*1024.0f)); printf("%s: max_lora_size = %zu bytes (%.1f MB)\n", __func__, lora.data.size(), (float) lora.data.size() / (1024.0f*1024.0f));
printf("%s: max_opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f)); printf("%s: max_opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
opt->iter = lora.train_its; opt->iter = train->train_its;
if (params.only_write_lora) { if (params.only_write_lora) {
if (strlen(params.fn_lora_out) > 0) { if (strlen(params.fn_lora_out) > 0) {
@ -2368,25 +2296,25 @@ int main(int argc, char ** argv) {
printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens); printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size()); size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
const bool changed_train_data = (shuffle_samples_hash != lora.shuffle_samples_hash) || (lora.shuffle_sample_count != train_samples_size.size()); const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
if (changed_train_data) { if (changed_train_data) {
printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__); printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
} }
if (params.force_reshuffle) { if (params.force_reshuffle) {
printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__); printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
} }
if ((lora.shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) { if ((train->shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) {
lora.shuffle_rng_state_current = mt19937_seed_to_state(params.seed); train->shuffle_rng_state_current = mt19937_seed_to_state(params.seed);
lora.shuffle_sample_count = train_samples_size.size(); train->shuffle_sample_count = train_samples_size.size();
lora.shuffle_next_sample = 0; train->shuffle_next_sample = 0;
lora.shuffle_samples_hash = shuffle_samples_hash; train->shuffle_samples_hash = shuffle_samples_hash;
} }
std::vector<size_t> train_shuffled_samples_begin; std::vector<size_t> train_shuffled_samples_begin;
std::vector<size_t> train_shuffled_samples_size; std::vector<size_t> train_shuffled_samples_size;
train_shuffled_samples_begin.resize(train_samples_begin.size()); train_shuffled_samples_begin.resize(train_samples_begin.size());
train_shuffled_samples_size.resize(train_samples_size.size()); train_shuffled_samples_size.resize(train_samples_size.size());
lora.shuffle_rng_state_next = shuffle_samples( train->shuffle_rng_state_next = shuffle_samples(
lora.shuffle_rng_state_current, train->shuffle_rng_state_current,
train_shuffled_samples_begin.data(), train_shuffled_samples_begin.data(),
train_shuffled_samples_size.data(), train_shuffled_samples_size.data(),
train_samples_begin.data(), train_samples_begin.data(),
@ -2397,7 +2325,7 @@ int main(int argc, char ** argv) {
struct opt_callback_data opt_cb_data; struct opt_callback_data opt_cb_data;
opt_cb_data.params = &params; opt_cb_data.params = &params;
opt_cb_data.opt = opt; opt_cb_data.train = train;
opt_cb_data.model = &model; opt_cb_data.model = &model;
opt_cb_data.lora = &lora; opt_cb_data.lora = &lora;
opt_cb_data.lctx = lctx; opt_cb_data.lctx = lctx;
@ -2442,13 +2370,13 @@ int main(int argc, char ** argv) {
int new_iters = opt->iter - opt_cb_data.last_save_iter; int new_iters = opt->iter - opt_cb_data.last_save_iter;
if (new_iters > 0) { if (new_iters > 0) {
lora.train_its += new_iters; train->train_its += new_iters;
lora.train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
lora.train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens; train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
if (strlen(params.fn_checkpoint_out) > 0) { if (strlen(params.fn_checkpoint_out) > 0) {
save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, opt, params.pattern_fn_it, opt->iter, params.fn_latest); save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, train, params.pattern_fn_it, opt->iter, params.fn_latest);
save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, opt, params.pattern_fn_it, -1, params.fn_latest); save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, train, params.pattern_fn_it, -1, params.fn_latest);
} }
if (strlen(params.fn_lora_out) > 0) { if (strlen(params.fn_lora_out) > 0) {
save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, opt->iter, params.fn_latest); save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, opt->iter, params.fn_latest);
@ -2458,6 +2386,7 @@ int main(int argc, char ** argv) {
} }
ggml_free(opt->ctx); ggml_free(opt->ctx);
free_train_state(train);
ggml_free(lora.ctx); ggml_free(lora.ctx);
llama_free(lctx); llama_free(lctx);
llama_free_model(lmodel); llama_free_model(lmodel);

View file

@ -65,34 +65,8 @@ struct my_llama_model {
struct ggml_tensor * output; struct ggml_tensor * output;
std::vector<my_llama_layer> layers; std::vector<my_llama_layer> layers;
uint64_t train_its = 0;
uint64_t train_samples = 0;
uint64_t train_tokens = 0;
uint64_t train_epochs = 0;
size_t shuffle_samples_hash; // fn, sample_count, *zip(sample_begins, sample_sizes)
std::string shuffle_rng_state_current;
std::string shuffle_rng_state_next;
size_t shuffle_sample_count;
size_t shuffle_next_sample;
}; };
// gguf constants
static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
static const char * LLM_KV_TRAINING_TYPE = "training.type";
static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";
static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count";
static const char * LLM_KV_TRAINING_SAMPLES_HASH = "training.samples_hash";
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state";
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
static const char * LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE = "training.shuffle.next_sample";
// gguf constants (sync with gguf.py) // gguf constants (sync with gguf.py)
static const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture"; static const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture";
@ -152,11 +126,6 @@ static void init_model(struct my_llama_model * model) {
struct ggml_context * ctx = model->ctx; struct ggml_context * ctx = model->ctx;
model->train_its = 0;
model->train_samples = 0;
model->train_tokens = 0;
model->train_epochs = 0;
std::vector<char> tn_buf; std::vector<char> tn_buf;
tn_buf.resize(GGML_MAX_NAME); tn_buf.resize(GGML_MAX_NAME);
auto tn = [&tn_buf](const char * key) -> const char * { auto tn = [&tn_buf](const char * key) -> const char * {
@ -685,62 +654,19 @@ static void save_llama_model_file(const char * filename, const char * fn_vocab_m
gguf_free(fctx); gguf_free(fctx);
} }
static void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct ggml_opt_context * opt) { static void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct train_state * train) {
load_llama_model_gguf(fctx, f_ggml_ctx, model); load_llama_model_gguf(fctx, f_ggml_ctx, model);
if (!load_train_state_gguf(fctx, f_ggml_ctx, train)) {
if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) >= 0) {
uint32_t file_version = 0xFFFFFFFFu;
GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
GGML_ASSERT(file_version <= 1);
std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL;
GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
if (file_version == 0) {
GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
} else if (file_version == 1) {
GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_ITERATION_COUNT);
GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_SAMPLE_COUNT);
GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_TOKEN_COUNT);
GGUF_GET_KEY(fctx, model->train_epochs, gguf_get_val_u64, GGUF_TYPE_UINT64, true, LLM_KV_TRAINING_EPOCH_COUNT);
GGUF_GET_KEY(fctx, model->shuffle_samples_hash, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH);
GGUF_GET_KEY(fctx, model->shuffle_rng_state_current, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_SHUFFLE_RNG_STATE);
GGUF_GET_KEY(fctx, model->shuffle_sample_count, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT);
GGUF_GET_KEY(fctx, model->shuffle_next_sample, gguf_get_val_u64, GGUF_TYPE_UINT64, false, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE);
}
load_opt_context_gguf(fctx, f_ggml_ctx, opt);
} else {
printf("%s: loaded llama model as checkpoint\n", __func__); printf("%s: loaded llama model as checkpoint\n", __func__);
} }
} }
static void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) { static void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) {
save_llama_model_gguf(fctx, fn_vocab_model, model); save_llama_model_gguf(fctx, fn_vocab_model, model);
save_train_state_gguf(fctx, train);
gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1);
gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, model->train_its);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, model->train_samples);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, model->train_tokens);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_EPOCH_COUNT, model->train_epochs);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH, (uint64_t) model->shuffle_samples_hash);
gguf_set_val_str(fctx, LLM_KV_TRAINING_SHUFFLE_RNG_STATE, model->shuffle_rng_state_current.c_str());
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT, (uint64_t) model->shuffle_sample_count);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SHUFFLE_NEXT_SAMPLE, (uint64_t) model->shuffle_next_sample);
save_opt_context_gguf(fctx, opt);
} }
static bool load_checkpoint_file(const char * filename, struct my_llama_model * model, struct ggml_opt_context * opt) { static bool load_checkpoint_file(const char * filename, struct my_llama_model * model, struct train_state * train) {
struct ggml_context * f_ggml_ctx; struct ggml_context * f_ggml_ctx;
struct gguf_init_params params; struct gguf_init_params params;
params.no_alloc = false; params.no_alloc = false;
@ -750,18 +676,18 @@ static bool load_checkpoint_file(const char * filename, struct my_llama_model *
return false; return false;
} }
load_checkpoint_gguf(fctx, f_ggml_ctx, model, opt); load_checkpoint_gguf(fctx, f_ggml_ctx, model, train);
return true; return true;
} }
static void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt, const char * pattern_it, int iteration, const char * latest) { static void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train, const char * pattern_it, int iteration, const char * latest) {
std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest);
std::string fn = replace_str(filename, pattern_it, sit.c_str()); std::string fn = replace_str(filename, pattern_it, sit.c_str());
printf("%s: saving to %s\n", __func__, fn.c_str()); printf("%s: saving to %s\n", __func__, fn.c_str());
struct gguf_context * fctx = gguf_init_empty(); struct gguf_context * fctx = gguf_init_empty();
save_checkpoint_gguf(fctx, fn_vocab_model, model, opt); save_checkpoint_gguf(fctx, fn_vocab_model, model, train);
// write file // write file
const bool only_meta = false; const bool only_meta = false;
@ -1296,7 +1222,7 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
struct opt_callback_data { struct opt_callback_data {
struct train_params * params; struct train_params * params;
struct ggml_opt_context * opt; struct train_state * train;
struct my_llama_model * model; struct my_llama_model * model;
struct llama_context * lctx; struct llama_context * lctx;
int last_save_iter; int last_save_iter;
@ -1318,7 +1244,8 @@ struct opt_callback_data {
static void opt_callback(void * vdata, int accum_step, float * sched) { static void opt_callback(void * vdata, int accum_step, float * sched) {
struct opt_callback_data * data = (struct opt_callback_data *) vdata; struct opt_callback_data * data = (struct opt_callback_data *) vdata;
struct train_params * params = data->params; struct train_params * params = data->params;
struct ggml_opt_context * opt = data->opt; struct train_state * train = data->train;
struct ggml_opt_context * opt = train->opt;
int n_batch = params->n_batch; int n_batch = params->n_batch;
int n_ctx = params->n_ctx; int n_ctx = params->n_ctx;
@ -1347,13 +1274,13 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every); const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
if (save_now) { if (save_now) {
int new_iters = opt->iter - data->last_save_iter; int new_iters = opt->iter - data->last_save_iter;
data->model->train_its += new_iters; train->train_its += new_iters;
data->model->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
data->model->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx; train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
if (strlen(params->fn_checkpoint_out) > 0) { if (strlen(params->fn_checkpoint_out) > 0) {
save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, opt->iter, params->fn_latest); save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, train, params->pattern_fn_it, opt->iter, params->fn_latest);
save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, -1, params->fn_latest); save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, train, params->pattern_fn_it, -1, params->fn_latest);
} }
if (strlen(params->fn_model_out) > 0) { if (strlen(params->fn_model_out) > 0) {
@ -1380,7 +1307,7 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
if (impr_plot > 0) impr_plot = 0; if (impr_plot > 0) impr_plot = 0;
if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0; if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0;
printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f", printf("%s: iter=%6d sample=%zu/%zu sched=%f loss=%f",
__func__, opt->iter, std::min(1+data->model->shuffle_next_sample, data->model->shuffle_sample_count), data->model->shuffle_sample_count, __func__, opt->iter, std::min(1+train->shuffle_next_sample, train->shuffle_sample_count), train->shuffle_sample_count,
*sched, opt->loss_after); *sched, opt->loss_after);
@ -1406,7 +1333,7 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
data->lctx, data->lctx,
data->tokens_input, data->tokens_input,
data->target_probs, data->target_probs,
data->model->shuffle_next_sample, train->shuffle_next_sample,
data->shuffled_samples_begin, data->shuffled_samples_begin,
data->shuffled_samples_size, data->shuffled_samples_size,
data->samples_count, data->samples_count,
@ -1416,21 +1343,21 @@ static void opt_callback(void * vdata, int accum_step, float * sched) {
params->separate_with_bos, params->separate_with_bos,
params->fill_with_next_samples); params->fill_with_next_samples);
data->model->shuffle_next_sample += used_samples; train->shuffle_next_sample += used_samples;
if (data->model->shuffle_next_sample >= data->model->shuffle_sample_count) { if (train->shuffle_next_sample >= train->shuffle_sample_count) {
++data->model->train_epochs; ++train->train_epochs;
printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) data->model->train_epochs); printf("%s: reshuffle samples. completed epochs: %llu\n", __func__, (long long unsigned) train->train_epochs);
// note: we may have used some samples from the current shuffling more than once // note: we may have used some samples from the current shuffling more than once
data->model->shuffle_rng_state_current = data->model->shuffle_rng_state_next; train->shuffle_rng_state_current = train->shuffle_rng_state_next;
data->model->shuffle_rng_state_next = shuffle_samples( train->shuffle_rng_state_next = shuffle_samples(
data->model->shuffle_rng_state_current, train->shuffle_rng_state_current,
data->shuffled_samples_begin, data->shuffled_samples_begin,
data->shuffled_samples_size, data->shuffled_samples_size,
data->samples_begin, data->samples_begin,
data->samples_size, data->samples_size,
data->samples_count); data->samples_count);
data->model->shuffle_next_sample = 0; train->shuffle_next_sample = 0;
} }
} }
@ -1480,8 +1407,8 @@ int main(int argc, char ** argv) {
int n_vocab = model.hparams.n_vocab; int n_vocab = model.hparams.n_vocab;
int n_batch = params.n_batch; int n_batch = params.n_batch;
struct ggml_opt_context * opt = (struct ggml_opt_context *) alloca(sizeof(struct ggml_opt_context)); struct train_state * train = init_train_state(params.seed);
memset(opt, 0, sizeof(struct ggml_opt_context)); struct ggml_opt_context * opt = train->opt;
struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM); struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
opt_params_adam.print_forward_graph = false; opt_params_adam.print_forward_graph = false;
@ -1505,7 +1432,7 @@ int main(int argc, char ** argv) {
opt->params = opt_params_adam; opt->params = opt_params_adam;
printf("%s: init model\n", __func__); printf("%s: init model\n", __func__);
bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, opt); bool existed = load_checkpoint_file(params.fn_checkpoint_in, &model, train);
if (!existed) { if (!existed) {
init_model(&model); init_model(&model);
} }
@ -1513,7 +1440,7 @@ int main(int argc, char ** argv) {
opt->params = opt_params_adam; opt->params = opt_params_adam;
opt->iter = model.train_its; opt->iter = train->train_its;
printf("%s: opt iter %d\n", __func__, opt->iter); printf("%s: opt iter %d\n", __func__, opt->iter);
bool from_scratch = !existed; bool from_scratch = !existed;
@ -1555,25 +1482,25 @@ int main(int argc, char ** argv) {
printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size()); printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size());
size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size()); size_t shuffle_samples_hash = compute_samples_hash(params.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
const bool changed_train_data = (shuffle_samples_hash != model.shuffle_samples_hash) || (model.shuffle_sample_count != train_samples_size.size()); const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
if (changed_train_data) { if (changed_train_data) {
printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__); printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
} }
if (params.force_reshuffle) { if (params.force_reshuffle) {
printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__); printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
} }
if ((model.shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) { if ((train->shuffle_rng_state_current == "") || changed_train_data || params.force_reshuffle) {
model.shuffle_rng_state_current = mt19937_seed_to_state(params.seed); train->shuffle_rng_state_current = mt19937_seed_to_state(params.seed);
model.shuffle_sample_count = train_samples_size.size(); train->shuffle_sample_count = train_samples_size.size();
model.shuffle_next_sample = 0; train->shuffle_next_sample = 0;
model.shuffle_samples_hash = shuffle_samples_hash; train->shuffle_samples_hash = shuffle_samples_hash;
} }
std::vector<size_t> train_shuffled_samples_begin; std::vector<size_t> train_shuffled_samples_begin;
std::vector<size_t> train_shuffled_samples_size; std::vector<size_t> train_shuffled_samples_size;
train_shuffled_samples_begin.resize(train_samples_begin.size()); train_shuffled_samples_begin.resize(train_samples_begin.size());
train_shuffled_samples_size.resize(train_samples_size.size()); train_shuffled_samples_size.resize(train_samples_size.size());
model.shuffle_rng_state_next = shuffle_samples( train->shuffle_rng_state_next = shuffle_samples(
model.shuffle_rng_state_current, train->shuffle_rng_state_current,
train_shuffled_samples_begin.data(), train_shuffled_samples_begin.data(),
train_shuffled_samples_size.data(), train_shuffled_samples_size.data(),
train_samples_begin.data(), train_samples_begin.data(),
@ -1583,7 +1510,7 @@ int main(int argc, char ** argv) {
struct opt_callback_data opt_cb_data; struct opt_callback_data opt_cb_data;
opt_cb_data.params = &params; opt_cb_data.params = &params;
opt_cb_data.opt = opt; opt_cb_data.train = train;
opt_cb_data.model = &model; opt_cb_data.model = &model;
opt_cb_data.lctx = lctx; opt_cb_data.lctx = lctx;
opt_cb_data.last_save_iter = opt->iter; opt_cb_data.last_save_iter = opt->iter;
@ -1672,9 +1599,9 @@ int main(int argc, char ** argv) {
size_t used_mem_after_opt = ggml_used_mem(ctx0); size_t used_mem_after_opt = ggml_used_mem(ctx0);
int n_iter = params.adam_n_iter; int n_iter = params.adam_n_iter;
model.train_its = opt->iter; train->train_its = opt->iter;
model.train_samples += n_batch * n_iter; train->train_samples += n_batch * n_iter;
model.train_tokens += n_batch * n_tokens * n_iter; train->train_tokens += n_batch * n_tokens * n_iter;
if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) { if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
printf("Example %d, opt iter %d\n", ex, opt->iter); printf("Example %d, opt iter %d\n", ex, opt->iter);
@ -1693,13 +1620,13 @@ int main(int argc, char ** argv) {
printf("%s: total training time=%f seconds\n", __func__, dd); printf("%s: total training time=%f seconds\n", __func__, dd);
int new_iters = opt->iter - opt_cb_data.last_save_iter; int new_iters = opt->iter - opt_cb_data.last_save_iter;
model.train_its += new_iters; train->train_its += new_iters;
model.train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
model.train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens; train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
if (params.n_examples > 0) { if (params.n_examples > 0) {
save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt, params.pattern_fn_it, opt->iter, params.fn_latest); save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, train, params.pattern_fn_it, opt->iter, params.fn_latest);
save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt, params.pattern_fn_it, -1, params.fn_latest); save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, train, params.pattern_fn_it, -1, params.fn_latest);
} }
if (strlen(params.fn_model_out) > 0) { if (strlen(params.fn_model_out) > 0) {
@ -1715,6 +1642,7 @@ int main(int argc, char ** argv) {
delete[] compute_addr; delete[] compute_addr;
delete[] compute_buf_0; delete[] compute_buf_0;
free_train_state(train);
ggml_free(model.ctx); ggml_free(model.ctx);
llama_free(lctx); llama_free(lctx);
llama_free_model(lmodel); llama_free_model(lmodel);