diff --git a/common/train.cpp b/common/train.cpp
index 1eec3e3fb..3724d75c2 100644
--- a/common/train.cpp
+++ b/common/train.cpp
@@ -497,15 +497,11 @@ static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.
 static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S  = "optimizer.lbfgs.memory_s";
 static const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y  = "optimizer.lbfgs.memory_y";
 
-static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL     = "train_model";
-static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA   = "finetune_lora";
-static const char * LLM_KV_TRAINING_TYPE                 = "training.type";
 static const char * LLM_KV_TRAINING_FILE_VERSION         = "training.file_version";
 static const char * LLM_KV_TRAINING_ITERATION_COUNT      = "training.iteration_count";
 static const char * LLM_KV_TRAINING_SAMPLE_COUNT         = "training.sample_count";
 static const char * LLM_KV_TRAINING_TOKEN_COUNT          = "training.token_count";
 static const char * LLM_KV_TRAINING_EPOCH_COUNT          = "training.epoch_count";
-static const char * LLM_KV_TRAINING_SAMPLES_HASH         = "training.samples_hash";
 static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLES_HASH = "training.shuffle.samples_hash";
 static const char * LLM_KV_TRAINING_SHUFFLE_RNG_STATE    = "training.shuffle.rng_state";
 static const char * LLM_KV_TRAINING_SHUFFLE_SAMPLE_COUNT = "training.shuffle.sample_count";
@@ -661,10 +657,6 @@ bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_g
     GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
     GGML_ASSERT(file_version <= 1);
 
-    std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
-    GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
-    GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
-
     if (file_version == 0) {
 
         GGUF_GET_KEY(fctx, train->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
@@ -690,7 +682,6 @@ bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_g
 
 void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train) {
     gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 1);
-    gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
     gguf_set_val_u64(fctx, LLM_KV_TRAINING_ITERATION_COUNT, train->train_its);
     gguf_set_val_u64(fctx, LLM_KV_TRAINING_SAMPLE_COUNT,    train->train_samples);
     gguf_set_val_u64(fctx, LLM_KV_TRAINING_TOKEN_COUNT,     train->train_tokens);
diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index 548075493..ae3582a54 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -151,6 +151,10 @@ struct my_llama_lora {
 };
 
 // gguf constants
+static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL   = "train_model";
+static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
+static const char * LLM_KV_TRAINING_TYPE               = "training.type";
+
 static const char * LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD  = "training.lora.rank.token_embd";
 static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM = "training.lora.rank.output_norm";
 static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT      = "training.lora.rank.output";
@@ -994,11 +998,16 @@ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_mod
 }
 
 static void load_checkpoint_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
-    load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora);
+    std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
+    GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
+    GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
+
     load_train_state_gguf(fctx, f_ggml_ctx, train);
+    load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora);
 }
 
 static void save_checkpoint_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
+    gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
     save_llama_lora_gguf(fctx, model, lora);
     save_train_state_gguf(fctx, train);
 }
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 88174e064..5c37776f3 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -68,6 +68,9 @@ struct my_llama_model {
 };
 
 // gguf constants (sync with gguf.py)
+static const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL   = "train_model";
+static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
+static const char * LLM_KV_TRAINING_TYPE               = "training.type";
 
 static const char * LLM_KV_GENERAL_ARCHITECTURE        = "general.architecture";
 static const char * LLM_KV_GENERAL_FILE_TYPE           = "general.file_type";
@@ -654,12 +657,17 @@ static void save_llama_model_file(const char * filename, const char * fn_vocab_m
 
 static void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct train_state * train) {
     load_llama_model_gguf(fctx, f_ggml_ctx, model);
-    if (!load_train_state_gguf(fctx, f_ggml_ctx, train)) {
+    if (load_train_state_gguf(fctx, f_ggml_ctx, train)) {
+        std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL;
+        GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
+        GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
+    } else {
        printf("%s: loaded llama model as checkpoint\n", __func__);
     }
 }
 
 static void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) {
+    gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
     save_llama_model_gguf(fctx, fn_vocab_model, model);
     save_train_state_gguf(fctx, train);
 }
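Reviewer note: the net effect of this patch is that training.type stops being hard-coded in the shared common/train.cpp (which previously stamped every checkpoint as finetune_lora) and becomes a per-example responsibility: each checkpoint writer records its own type, and each loader validates it before trusting the rest of the state. Below is a minimal sketch of that check-on-load pattern, assuming the GGUF_GET_KEY convenience macro used throughout these files is in scope; the helper name expect_training_type is illustrative only and not part of the patch.

// Sketch, not part of the patch: the check-on-load pattern each example now owns.
// Assumes ggml.h provides the gguf API and GGML_ASSERT, and that the
// GGUF_GET_KEY convenience macro defined alongside these files is visible.
#include <string>
#include "ggml.h"

// Hypothetical helper; both loaders in the diff inline this logic instead.
static void expect_training_type(struct gguf_context * fctx, const std::string & expected) {
    // Default to the expected type so old checkpoints written before
    // training.type existed still load.
    std::string train_type = expected;
    // Optional read (required = false): overwrites train_type only if the key is present.
    GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, "training.type");
    // Hard-fail when a checkpoint of the wrong kind is loaded, e.g. a LoRA
    // checkpoint handed to train-text-from-scratch.
    GGML_ASSERT(train_type == expected);
}

// Usage mirroring the diff:
//   finetune.cpp:                expect_training_type(fctx, "finetune_lora");
//   train-text-from-scratch.cpp: expect_training_type(fctx, "train_model");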