add LLM_KV_TRAINING_TYPE to train-text-from-scratch checkpoints
so that they can be differentiated from lora finetune checkpoints
parent: ca97583f0b
commit: e030f7b2c5
2 changed files with 20 additions and 8 deletions
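As a quick illustration of what the new key buys: the sketch below (not part of this commit) opens a checkpoint GGUF file and prints its training.type, so a "train_model" checkpoint can be told apart from a "finetune_lora" one. It assumes the ggml gguf C API (gguf_init_from_file, gguf_find_key, gguf_get_val_str, gguf_free); where those declarations live (ggml.h vs. a separate gguf header) depends on the ggml version.

// Minimal sketch (not from this commit): inspect the training.type of a
// checkpoint GGUF file using the ggml gguf C API.
#include "ggml.h"

#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <checkpoint.gguf>\n", argv[0]);
        return 1;
    }

    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
    struct gguf_context * ctx = gguf_init_from_file(argv[1], params);
    if (!ctx) {
        fprintf(stderr, "failed to open %s\n", argv[1]);
        return 1;
    }

    const int kid = gguf_find_key(ctx, "training.type");
    if (kid < 0) {
        // checkpoints written before this commit carry no training.type key;
        // they are train-text-from-scratch checkpoints by definition
        printf("training.type not set (pre-existing train_model checkpoint)\n");
    } else {
        printf("training.type = %s\n", gguf_get_val_str(ctx, kid));
    }

    gguf_free(ctx);
    return 0;
}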
Changes to the Python checkpoint script:

@@ -44,10 +44,13 @@ LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys"
 LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S  = "optimizer.lbfgs.memory_s"
 LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y  = "optimizer.lbfgs.memory_y"
 
-LLM_KV_TRAINING_FILE_VERSION    = "training.file_version"
-LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count"
-LLM_KV_TRAINING_SAMPLE_COUNT    = "training.sample_count"
-LLM_KV_TRAINING_TOKEN_COUNT     = "training.token_count"
+LLM_KV_TRAINING_TYPE_TRAIN_MODEL   = "train_model"
+LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora"
+LLM_KV_TRAINING_TYPE               = "training.type"
+LLM_KV_TRAINING_FILE_VERSION       = "training.file_version"
+LLM_KV_TRAINING_ITERATION_COUNT    = "training.iteration_count"
+LLM_KV_TRAINING_SAMPLE_COUNT       = "training.sample_count"
+LLM_KV_TRAINING_TOKEN_COUNT        = "training.token_count"
 
 class Tensor:
     def __init__(self, dtype='f', ne=None):
@@ -457,6 +460,7 @@ class Checkpoint:
         gguf_writer.add_file_type(gguf.GGMLQuantizationType.F32)
         gguf_writer.add_layer_norm_rms_eps(1e-5)
         gguf_writer.add_uint32(LLM_KV_TRAINING_FILE_VERSION, 0)
+        gguf_writer.add_string(LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL)
         gguf_writer.add_uint32(LLM_KV_TRAINING_ITERATION_COUNT, self.train_its)
         gguf_writer.add_uint32(LLM_KV_TRAINING_SAMPLE_COUNT, self.train_samples)
         gguf_writer.add_uint32(LLM_KV_TRAINING_TOKEN_COUNT, self.train_tokens)
Changes to the C++ train-text-from-scratch example:

@@ -246,10 +246,13 @@ const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.m
 const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S  = "optimizer.lbfgs.memory_s";
 const char * LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y  = "optimizer.lbfgs.memory_y";
 
-const char * LLM_KV_TRAINING_FILE_VERSION    = "training.file_version";
-const char * LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count";
-const char * LLM_KV_TRAINING_SAMPLE_COUNT    = "training.sample_count";
-const char * LLM_KV_TRAINING_TOKEN_COUNT     = "training.token_count";
+const char * LLM_KV_TRAINING_TYPE_TRAIN_MODEL   = "train_model";
+const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
+const char * LLM_KV_TRAINING_TYPE               = "training.type";
+const char * LLM_KV_TRAINING_FILE_VERSION       = "training.file_version";
+const char * LLM_KV_TRAINING_ITERATION_COUNT    = "training.iteration_count";
+const char * LLM_KV_TRAINING_SAMPLE_COUNT       = "training.sample_count";
+const char * LLM_KV_TRAINING_TOKEN_COUNT        = "training.token_count";
 
 // gguf constants (sync with gguf.py)
 
@@ -1431,6 +1434,10 @@ void load_checkpoint_gguf(struct gguf_context * fctx, struct ggml_context * f_gg
     GGUF_GET_KEY(fctx, file_version, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_FILE_VERSION);
     GGML_ASSERT(file_version == 0);
 
+    std::string train_type = LLM_KV_TRAINING_TYPE_TRAIN_MODEL;
+    GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
+    GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
+
     GGUF_GET_KEY(fctx, model->train_its, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_ITERATION_COUNT);
     GGUF_GET_KEY(fctx, model->train_samples, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_SAMPLE_COUNT);
     GGUF_GET_KEY(fctx, model->train_tokens, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_TOKEN_COUNT);
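Note how backward compatibility is handled in the loader: train_type is pre-set to "train_model" and the key is read with the required flag set to false, so checkpoints written before this change (which carry no training.type at all) still load, while a checkpoint tagged "finetune_lora" trips the assert. Below is a rough sketch of what such an optional string read amounts to, assuming GGUF_GET_KEY wraps gguf_find_key / gguf_get_kv_type / gguf_get_val_str (the macro itself is outside this diff; the helper name is made up):

// Hypothetical helper approximating the optional-key case of GGUF_GET_KEY,
// built on the ggml gguf C API.
#include "ggml.h"

#include <string>

static void get_optional_str(struct gguf_context * fctx, std::string & dst, const char * key) {
    const int kid = gguf_find_key(fctx, key);
    if (kid < 0) {
        return; // key absent: keep the caller's default ("train_model" above)
    }
    GGML_ASSERT(gguf_get_kv_type(fctx, kid) == GGUF_TYPE_STRING);
    dst = gguf_get_val_str(fctx, kid);
}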
@@ -1442,6 +1449,7 @@ void save_checkpoint_gguf(struct gguf_context * fctx, const char * fn_vocab_mode
     save_llama_model_gguf(fctx, fn_vocab_model, model);
 
     gguf_set_val_u32(fctx, LLM_KV_TRAINING_FILE_VERSION, 0);
+    gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_TRAIN_MODEL);
     gguf_set_val_u32(fctx, LLM_KV_TRAINING_ITERATION_COUNT, model->train_its);
     gguf_set_val_u32(fctx, LLM_KV_TRAINING_SAMPLE_COUNT, model->train_samples);
     gguf_set_val_u32(fctx, LLM_KV_TRAINING_TOKEN_COUNT, model->train_tokens);
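The "finetune_lora" value is declared on both sides but never written in this diff; presumably the lora finetune example tags its own checkpoints with it so the loader above can reject them. A one-line sketch of that counterpart (hypothetical, reusing the constants and gguf_set_val_str from this diff):

// Hypothetical counterpart (not in this diff): how a lora finetune checkpoint
// writer would tag its files so they are distinguishable from train_model ones.
gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);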