add export of training checkpoint to llama compatible model file

xaedes 2023-05-29 01:27:09 +02:00
parent 2da5c8cf24
commit 4b81c32d5b

@@ -150,6 +150,19 @@ struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struc
return tensor;
}
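// trimmed-down copy of llama.cpp's vocabulary types; save_as_llama_model below
// needs the token strings and scores to write the vocab section of the model file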
struct llama_vocab {
using id = int32_t;
using token = std::string;
struct token_score {
token tok;
float score;
};
std::unordered_map<token, id> token_to_id;
std::vector<token_score> id_to_token;
};
struct my_llama_hparams {
uint32_t n_vocab = 32000;
uint32_t n_ctx = 512; // this is provided as user input?
@@ -278,9 +291,20 @@ void init_model(struct my_llama_model * model) {
ggml_set_name(layer.ffn_norm, (layers_i + ".ffn_norm.weight").c_str());
ggml_set_name(layer.w1, (layers_i + ".feed_forward.w1.weight").c_str());
ggml_set_name(layer.w2, (layers_i + ".feed_forward.w2.weight").c_str());
ggml_set_name(layer.w3, (layers_i + ".feed_forward.w3.weight").c_str());
// 'layers.10.feed_forward.w1.weight' is 32 characters long.
// ggml_tensor->name only holds 32 characters, but we need one more for the '\0' terminator.
// ggml_set_name sets the last character to '\0', so it could only store 'layers.10.feed_forward.w1.weigh'.
// when saving a llama-compatible model, those tensor names would each be missing their last character.
// ggml_set_name(layer.w1, (layers_i + ".feed_forward.w1.weight").c_str());
// ggml_set_name(layer.w2, (layers_i + ".feed_forward.w2.weight").c_str());
// ggml_set_name(layer.w3, (layers_i + ".feed_forward.w3.weight").c_str());
strncpy(layer.w1->name, (layers_i + ".feed_forward.w1.weight").c_str(), sizeof(layer.w1->name));
strncpy(layer.w2->name, (layers_i + ".feed_forward.w2.weight").c_str(), sizeof(layer.w2->name));
strncpy(layer.w3->name, (layers_i + ".feed_forward.w3.weight").c_str(), sizeof(layer.w3->name));
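// strncpy leaves the name unterminated when the source fills all 32 bytes;
// zeroing the first padding byte, which follows name in the struct, terminates it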
layer.w1->padding[0] = 0;
layer.w2->padding[0] = 0;
layer.w3->padding[0] = 0;
}
}
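The truncation issue described in the comments above is easy to reproduce in isolation. A minimal standalone sketch; the struct below only mimics the name/padding layout of ggml_tensor and is not part of this commit:

#include <stdio.h>
#include <string.h>

int main(void) {
    // mimic the relevant ggml_tensor layout: a 32-byte name followed by padding
    struct { char name[32]; char padding[16]; } t;
    const char * long_name = "layers.10.feed_forward.w1.weight"; // exactly 32 chars

    // strncpy copies all 32 bytes and, because the source fills the buffer,
    // writes no '\0' terminator into name
    strncpy(t.name, long_name, sizeof(t.name));
    // terminating via the adjacent padding byte preserves the full 32-char name
    t.padding[0] = 0;

    printf("'%s' (len %zu)\n", t.name, strlen(t.name)); // full name, length 32
    return 0;
}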
@@ -1584,13 +1608,6 @@ void set_logits_masked(struct ggml_tensor * logits, std::vector<bool>& mask, flo
}
}
enum llama_file_version {
LLAMA_FILE_VERSION_GGML,
LLAMA_FILE_VERSION_GGMF_V1, // added version field and scores in vocab
LLAMA_FILE_VERSION_GGJT_V1, // added padding
LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format
};
void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
if (tensor == NULL) {
file->write_u32(0);
@@ -1627,7 +1644,7 @@ void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
}
std::string name = file->read_string(name_len);
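// the stored name may have been written by strncpy without a '\0' terminator,
// so bound the comparison to the size of the name buffer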
GGML_ASSERT(strcmp(ggml_get_name(tensor), name.c_str()) == 0);
GGML_ASSERT(strncmp(ggml_get_name(tensor), name.c_str(), sizeof(tensor->name)) == 0);
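// tensor data is aligned to 32 bytes in the file; skip ahead to the next boundary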
file->seek(-file->tell() & 31, SEEK_CUR);
file->read_raw(tensor->data, ggml_nbytes(tensor));
@@ -1839,6 +1856,50 @@ bool load_checkpoint(struct my_llama_model * model, struct ggml_opt_context * op
return (file.fp != NULL);
}
void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * model, const char * filename) {
struct llama_file file(filename, "wb");
if (file.fp == NULL) {
return;
}
// write_magic
file.write_u32(LLAMA_FILE_MAGIC); // magic
file.write_u32(LLAMA_FILE_VERSION); // version
// write_hparams
file.write_u32(model->hparams.n_vocab);
file.write_u32(model->hparams.n_embd);
file.write_u32(model->hparams.n_mult);
file.write_u32(model->hparams.n_head);
file.write_u32(model->hparams.n_layer);
file.write_u32(model->hparams.n_rot);
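// ftype: every tensor in this export is stored as plain f32, no quantization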
file.write_u32(LLAMA_FTYPE_ALL_F32);
// write_vocab
uint32_t n_vocab = model->hparams.n_vocab;
for (uint32_t i = 0; i < n_vocab; i++) {
const auto & token_score = vocab->id_to_token.at(i);
file.write_u32((uint32_t) token_score.tok.size());
file.write_raw(token_score.tok.data(), token_score.tok.size());
file.write_raw(&token_score.score, sizeof(token_score.score));
}
// write tensors
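// the loader looks tensors up by name, so the names set in init_model must
// match llama.cpp's conventions exactly, including the full w1/w2/w3 names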
write_tensor(&file, model->tok_embeddings);
write_tensor(&file, model->norm);
write_tensor(&file, model->output);
for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
auto & layer = model->layers[i];
write_tensor(&file, layer.attention_norm);
write_tensor(&file, layer.wq);
write_tensor(&file, layer.wk);
write_tensor(&file, layer.wv);
write_tensor(&file, layer.wo);
write_tensor(&file, layer.ffn_norm);
write_tensor(&file, layer.w1);
write_tensor(&file, layer.w2);
write_tensor(&file, layer.w3);
}
}
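Not part of this commit, but a quick way to sanity-check the export is to load the written file back through the llama.cpp API this program already uses for the vocab. A minimal sketch, assuming llama_init_from_file, the loader entry point in this era of llama.cpp; if the loader accepts the file, the magic, hparams, vocab, and tensor sections are laid out as llama.cpp expects:

void check_export(const char * fn_model_out) {
    struct llama_context_params params = llama_context_default_params();
    struct llama_context * ctx = llama_init_from_file(fn_model_out, params);
    if (ctx == NULL) {
        fprintf(stderr, "llama loader rejected '%s'\n", fn_model_out);
        return;
    }
    llama_free(ctx);
}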
float cosine_decay(const int decay_steps, const float alpha, int step) {
if (step > decay_steps) {
step = decay_steps;
@@ -1861,10 +1922,11 @@ int main(int argc, char ** argv) {
const char * default_train = "shakespeare.txt";
const char * default_chkpt_in = "checkpoint.bin";
const char * default_chkpt_out = "checkpoint.bin";
const char * default_argv[5] = {argv[0], default_model, default_train, default_chkpt_in, default_chkpt_out};
const char * default_model_out = "ggml-checkpoint-f32.bin";
const char * default_argv[6] = {argv[0], default_model, default_train, default_chkpt_in, default_chkpt_out, default_model_out};
if (argc < 5) {
fprintf(stderr, "usage: %s model training_data chkpt_in chkpt_out\n", argv[0]);
if (argc < 6) {
fprintf(stderr, "usage: %s model training_data chkpt_in chkpt_out model_out\n", argv[0]);
//return 1;
}
@@ -1874,6 +1936,7 @@ int main(int argc, char ** argv) {
const char * fn_train = (argc >= 3) ? argv[2] : default_argv[2];
const char * fn_chkpt_in = (argc >= 4) ? argv[3] : default_argv[3];
const char * fn_chkpt_out = (argc >= 5) ? argv[4] : default_argv[4];
const char * fn_model_out = (argc >= 6) ? argv[5] : default_argv[5];
struct llama_context_params llama_params = llama_context_default_params();
llama_params.vocab_only = true;
@@ -1970,6 +2033,8 @@ int main(int argc, char ** argv) {
bool existed = load_checkpoint(&model, opt, fn_chkpt_in, true);
set_param_model(&model);
opt->params = use_adam ? opt_params_adam : opt_params_lbfgs;
opt->iter = model.train_its;
printf("%s: opt iter %d\n", __func__, opt->iter);
@@ -2105,6 +2170,10 @@ int main(int argc, char ** argv) {
save_checkpoint(&model, opt, fn_chkpt_out);
}
if (strlen(fn_model_out) > 0) {
save_as_llama_model(&vocab, &model, fn_model_out);
}
{
int n_gen = 1024;
int sample_ctx = n_tokens - n_tokens/8;