fix saving and loading of training type
This commit is contained in:
parent
1d09965179
commit
9db2664dd1
3 changed files with 19 additions and 11 deletions
|
@@ -497,15 +497,11 @@ static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.
|
|||
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s";
|
||||
static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y";
|
||||
|
||||
static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
|
||||
static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
|
||||
static const char * LLM_KV_TRAINING_TYPE = "training.type";
|
||||
static const char * LLM_KV_TRAINING_FILE_VERSION = "training.file_version";
|
||||
static const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
|
||||
static const char * LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count";
|
||||
static const char * LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count";
|
||||
static const char * LLM_KV_TRAINING_EPOCH_COUNT = "training.epoch_count";
|
||||
static const char * LLM_KV_TRAINING_SAMPLES_HASH = "training.samples_hash";
|
||||
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
|
||||
static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE = "training.shuffle.rng_state";
|
||||
static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
|
||||
|
@@ -661,10 +657,6 @@ bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_g
|
|||
GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
|
||||
GGML_ASSERT(file_version <= 1);
|
||||
|
||||
std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
|
||||
GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
|
||||
GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
|
||||
|
||||
if (file_version == 0) {
|
||||
|
||||
GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
|
||||
|
@@ -690,7 +682,6 @@ bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_g
|
|||
|
||||
void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) {
|
||||
gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1);
|
||||
gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
|
||||
gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its);
|
||||
gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, train->train_samples);
|
||||
gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT, train->train_tokens);
|
||||
|
|
|
@@ -151,6 +151,10 @@ struct my_llama_lora {
|
|||
};
|
||||
|
||||
// gguf constants
|
||||
static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
|
||||
static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
|
||||
static const char * LLM_KV_TRAINING_TYPE = "training.type";
|
||||
|
||||
static const char * LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD = "training.lora.rank.token_embd";
|
||||
static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM = "training.lora.rank.output_norm";
|
||||
static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT = "training.lora.rank.output";
|
||||
|
@@ -994,11 +998,16 @@ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_mod
|
|||
}
|
||||
|
||||
static void load_checkpoint_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
|
||||
load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora);
|
||||
std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
|
||||
GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
|
||||
GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
|
||||
|
||||
load_train_state_gguf(fctx, f_ggml_ctx, train);
|
||||
load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora);
|
||||
}
|
||||
|
||||
static void save_checkpoint_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
|
||||
gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
|
||||
save_llama_lora_gguf(fctx, model, lora);
|
||||
save_train_state_gguf(fctx, train);
|
||||
}
|
||||
|
|
|
@@ -68,6 +68,9 @@ struct my_llama_model {
|
|||
};
|
||||
|
||||
// gguf constants (sync with gguf.py)
|
||||
static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model";
|
||||
static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
|
||||
static const char * LLM_KV_TRAINING_TYPE = "training.type";
|
||||
|
||||
static const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture";
|
||||
static const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type";
|
||||
|
@@ -654,12 +657,17 @@ static void save_llama_model_file(const char * filename, const char * fn_vocab_m
|
|||
|
||||
static void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct train_state * train) {
|
||||
load_llama_model_gguf(fctx, f_ggml_ctx, model);
|
||||
if (!load_train_state_gguf(fctx, f_ggml_ctx, train)) {
|
||||
if (load_train_state_gguf(fctx, f_ggml_ctx, train)) {
|
||||
std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL;
|
||||
GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
|
||||
GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
|
||||
} else {
|
||||
printf("%s: loaded llama model as checkpoint\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
static void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) {
|
||||
gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
|
||||
save_llama_model_gguf(fctx, fn_vocab_model, model);
|
||||
save_train_state_gguf(fctx, train);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue