From fc456edda6b4b38308d193a4cb8bac4cfe92f35a Mon Sep 17 00:00:00 2001 From: xaedes Date: Wed, 30 Aug 2023 15:57:17 +0200 Subject: [PATCH] train-text-from-scratch can train (full finetune) gguf models just pass the gguf model via `--checkpoint-in FN`. after this, to continue training, pass the generated checkpoint instead of the original gguf model. tested with smaller models, bigger models may exceed available memory. use (LORA) finetune for those. --- .../train-text-from-scratch.cpp | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index aaae5e576..76cf501a5 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1188,19 +1188,23 @@ void save_llama_model_file(const char * filename, const char * fn_vocab_model, s void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct ggml_opt_context * opt) { load_llama_model_gguf(fctx, f_ggml_ctx, model); - uint32_t file_version; - GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION); - GGML_ASSERT(file_version == 0); + if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) >= 0) { + uint32_t file_version = 0xFFFFFFFFu; + GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION); + GGML_ASSERT(file_version == 0); - std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL; - GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE); - GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL); + std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL; + GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE); + GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL); - GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT); - GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT); - GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT); + GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT); + GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT); + GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT); - load_opt_context_gguf(fctx, f_ggml_ctx, opt); + load_opt_context_gguf(fctx, f_ggml_ctx, opt); + } else { + printf("%s: loaded llama model as checkpoint\n", __func__); + } } void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) {