fix saving and loading of training type

Author: xaedes
Date:   2023-09-16 21:21:04 +02:00
Commit: 9db2664dd1
Parent: 1d09965179
3 changed files with 19 additions and 11 deletions
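Summary: the training type used to be handled inside the shared train-state (de)serialization — save_train_state_gguf unconditionally wrote training.type = "finetune_lora" and load_train_state_gguf asserted that value, so checkpoints written by train-text-from-scratch carried (and required) the wrong type. The key is now written and verified by each example program instead: finetune uses "finetune_lora", train-text-from-scratch uses "train_model". A minimal sketch of the key round-trip through ggml's gguf key-value API (illustrative only, not code from this commit; the exact header path is an assumption):

#include "ggml.h"   // the gguf_* API ships with ggml; include path may differ
#include <cassert>
#include <string>

int main() {
    // write side: tag an in-memory gguf context with a training type
    struct gguf_context * fctx = gguf_init_empty();
    gguf_set_val_str(fctx, "training.type", "finetune_lora");

    // read side: look the key up by name and compare against the expected type
    const int kid = gguf_find_key(fctx, "training.type");
    assert(kid >= 0);
    const std::string train_type = gguf_get_val_str(fctx, kid);
    assert(train_type == "finetune_lora");

    gguf_free(fctx);
    return 0;
}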

common/train.cpp

@@ -497,15 +497,11 @@ static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.
 static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s";
 static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y";
-static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
-static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
-static const char * LLM_KV_TRAINING_TYPE = "training.type";
 static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
 static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
 static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
 static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";
 static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count";
-static const char * LLM_KV_TRAINING_SAMPLES_HASH = "training.samples_hash";
 static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
 static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state";
 static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
@@ -661,10 +657,6 @@ bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_g
     GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
     GGML_ASSERT(file_version <= 1);
 
-    std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
-    GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
-    GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
-
     if (file_version == 0) {
         GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
@@ -690,7 +682,6 @@ bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_g
 void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) {
     gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1);
-    gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
     gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its);
     gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, train->train_samples);
     gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, train->train_tokens);
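Note the `false` argument in the removed GGUF_GET_KEY call above: the key was read as optional, so a file without it leaves the destination string at its default. The example programs below reuse exactly this pattern, which is what keeps checkpoints written before training.type existed loadable. A hedged sketch of what the macro boils down to for an optional string key (the helper name is mine, not from the tree):

// Illustrative helper only; the real code uses the GGUF_GET_KEY macro.
static void get_str_optional(struct gguf_context * fctx, std::string & dst, const char * key) {
    const int kid = gguf_find_key(fctx, key);
    if (kid >= 0) {
        dst = gguf_get_val_str(fctx, kid);  // key present: overwrite the default
    }
    // key absent: dst keeps the caller's default, so older files still load
}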

examples/finetune/finetune.cpp

@@ -151,6 +151,10 @@ struct my_llama_lora {
 };
 
 // gguf constants
+static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
+static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
+static const char * LLM_KV_TRAINING_TYPE = "training.type";
+
 static const char * LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD = "training.lora.rank.token_embd";
 static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM = "training.lora.rank.output_norm";
 static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT = "training.lora.rank.output";
@@ -994,11 +998,16 @@ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_mod
 }
 
 static void load_checkpoint_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
-    load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora);
+    std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
+    GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
+    GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
+
     load_train_state_gguf(fctx, f_ggml_ctx, train);
+    load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora);
 }
 
 static void save_checkpoint_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
+    gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
     save_llama_lora_gguf(fctx, model, lora);
     save_train_state_gguf(fctx, train);
 }
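load_checkpoint_lora_gguf now validates the declared type before reading any tensors, so pointing finetune at a "train_model" checkpoint fails fast at the GGML_ASSERT instead of after partially loading LoRA data. Roughly, the load path behaves like this sketch (function name and file path are hypothetical; it leans on the constants and GGML_ASSERT visible in the diff above, and on ggml's gguf_init_from_file):

// Open a checkpoint file and run the same type check the diff adds.
static void check_finetune_checkpoint(const char * fn_checkpoint) {
    struct ggml_context * f_ggml_ctx = NULL;
    struct gguf_init_params params = { /*.no_alloc =*/ false, /*.ctx =*/ &f_ggml_ctx };
    struct gguf_context * fctx = gguf_init_from_file(fn_checkpoint, params);
    GGML_ASSERT(fctx != NULL);

    std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA; // default when the key is absent
    const int kid = gguf_find_key(fctx, LLM_KV_TRAINING_TYPE);
    if (kid >= 0) {
        train_type = gguf_get_val_str(fctx, kid);
    }
    // aborts here when the file declares "train_model"
    GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);

    gguf_free(fctx);
    ggml_free(f_ggml_ctx);
}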

examples/train-text-from-scratch/train-text-from-scratch.cpp

@@ -68,6 +68,9 @@ struct my_llama_model {
 };
 
 // gguf constants (sync with gguf.py)
+static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
+static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
+static const char * LLM_KV_TRAINING_TYPE = "training.type";
 static const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture";
 static const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type";
@@ -654,12 +657,17 @@ static void save_llama_model_file(const char * filename, const char * fn_vocab_m
 static void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct train_state * train) {
     load_llama_model_gguf(fctx, f_ggml_ctx, model);
-    if (!load_train_state_gguf(fctx, f_ggml_ctx, train)) {
+    if (load_train_state_gguf(fctx, f_ggml_ctx, train)) {
+        std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL;
+        GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
+        GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
+    } else {
         printf("%s: loaded llama model as checkpoint\n", __func__);
     }
 }
 
 static void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) {
+    gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
     save_llama_model_gguf(fctx, fn_vocab_model, model);
     save_train_state_gguf(fctx, train);
 }
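With the key handling split per program, each loader defaults train_type to its own expected value when the key is missing, so pre-existing checkpoints keep loading. My summary of the resulting behavior, inferred from the diff rather than taken from the commit:

// Outcome sketch after this commit:
//   checkpoint contents                 finetune              train-text-from-scratch
//   training.type == "finetune_lora"    loads                 GGML_ASSERT aborts
//   training.type == "train_model"      GGML_ASSERT aborts    loads
//   key absent (pre-existing file)      loads (default type)  loads (default type)
//   no train state in file              --                    loaded as plain llama model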