train-text-from-scratch can now train (full finetune) GGUF models

Just pass the GGUF model via `--checkpoint-in FN`.
After this, to continue training, pass the generated checkpoint instead of the original GGUF model.

Tested with smaller models; bigger models may exceed available memory.
Use (LoRA) finetune for those instead.
This commit is contained in:
xaedes 2023-08-30 15:57:17 +02:00
parent e6b7158123
commit fc456edda6
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1188,19 +1188,23 @@ void save_llama_model_file(const char * filename, const char * fn_vocab_model, s
void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct ggml_opt_context * opt) { void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct ggml_opt_context * opt) {
load_llama_model_gguf(fctx, f_ggml_ctx, model); load_llama_model_gguf(fctx, f_ggml_ctx, model);
uint32_t file_version; if (gguf_find_key(fctx, LLM_KV_TRAINING_FILE_VERSION) >= 0) {
GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION); uint32_t file_version = 0xFFFFFFFFu;
GGML_ASSERT(file_version == 0); GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
GGML_ASSERT(file_version == 0);
std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL; std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL;
GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE); GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL); GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT); GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT); GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT); GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
load_opt_context_gguf(fctx, f_ggml_ctx, opt); load_opt_context_gguf(fctx, f_ggml_ctx, opt);
} else {
printf("%s: loaded llama model as checkpoint\n", __func__);
}
} }
void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) { void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) {