fix saving and loading of training type

This commit is contained in:
xaedes 2023-09-16 21:21:04 +02:00
parent 1d09965179
commit 9db2664dd1
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
3 changed files with 19 additions and 11 deletions

View file

@ -497,15 +497,11 @@ static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s";
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y";
static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
static const char * LLM_KV_TRAINING_TYPE = "training.type";
static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";
static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count";
static const char * LLM_KV_TRAINING_SAMPLES_HASH = "training.samples_hash";
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state";
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
@ -661,10 +657,6 @@ bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_g
GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
GGML_ASSERT(file_version <= 1);
std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
if (file_version == 0) {
GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
@ -690,7 +682,6 @@ bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_g
void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) {
gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1);
gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, train->train_samples);
gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, train->train_tokens);

View file

@ -151,6 +151,10 @@ struct my_llama_lora {
};
// gguf constants
static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
static const char * LLM_KV_TRAINING_TYPE = "training.type";
static const char * LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD = "training.lora.rank.token_embd";
static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM = "training.lora.rank.output_norm";
static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT = "training.lora.rank.output";
@ -994,11 +998,16 @@ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_mod
}
static void load_checkpoint_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora);
std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
load_train_state_gguf(fctx, f_ggml_ctx, train);
load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora);
}
static void save_checkpoint_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
save_llama_lora_gguf(fctx, model, lora);
save_train_state_gguf(fctx, train);
}

View file

@ -68,6 +68,9 @@ struct my_llama_model {
};
// gguf constants (sync with gguf.py)
static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
static const char * LLM_KV_TRAINING_TYPE = "training.type";
static const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture";
static const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type";
@ -654,12 +657,17 @@ static void save_llama_model_file(const char * filename, const char * fn_vocab_m
static void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct train_state * train) {
load_llama_model_gguf(fctx, f_ggml_ctx, model);
if (!load_train_state_gguf(fctx, f_ggml_ctx, train)) {
if (load_train_state_gguf(fctx, f_ggml_ctx, train)) {
std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL;
GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
} else {
printf("%s: loaded llama model as checkpoint\n", __func__);
}
}
static void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) {
gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
save_llama_model_gguf(fctx, fn_vocab_model, model);
save_train_state_gguf(fctx, train);
}