also support loading from llama2.c vocabulary
parent d2b95e7e70
commit aa26201291
1 changed file with 57 additions and 25 deletions
@@ -438,6 +438,11 @@ struct llama_file {
         read_raw(&ret, sizeof(ret));
         return ret;
     }
+    std::float_t read_f32() {
+        std::float_t ret;
+        read_raw(&ret, sizeof(ret));
+        return ret;
+    }
 
     std::string read_string(std::uint32_t len) {
         std::vector<char> chars(len);
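Note on the new helper: read_f32() follows the same pattern as the existing read_u32()/read_string() members, copying the raw bytes of a little-endian IEEE-754 float out of the stream. A minimal standalone sketch of the same idea; read_f32_le is a hypothetical name, not llama.cpp API:

```cpp
#include <cstdio>

// Sketch of the pattern behind llama_file::read_f32(): read the 4 raw bytes of
// an IEEE-754 float. Assumes the host shares the file's (little-endian) byte
// order, as the llama2.c exporter does. Hypothetical helper, not llama.cpp API.
static float read_f32_le(std::FILE * fp) {
    float ret = 0.0f;
    if (std::fread(&ret, sizeof(ret), 1, fp) != 1) {
        std::perror("read_f32_le");
    }
    return ret;
}
```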
@@ -491,30 +496,57 @@ void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
     file->write_raw(tensor->data, ggml_nbytes(tensor));
 }
 
-void load_vocab(const char *filename, struct llama_vocab *vocab) {
-    struct llama_context_params llama_params = llama_context_default_params();
-    llama_params.vocab_only = true;
-
-    struct llama_model * lmodel = llama_load_model_from_file(filename, llama_params);
-    struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
-
-    std::vector<const char *> strings;
-    std::vector<float> scores;
-    int n_vocab = llama_n_vocab(lctx);
-    strings.resize(n_vocab, NULL);
-    scores.resize(n_vocab, 0);
-    n_vocab = llama_get_vocab(lctx, strings.data(), scores.data(), n_vocab);
-    GGML_ASSERT(n_vocab == llama_n_vocab(lctx));
-    vocab->id_to_token.resize(n_vocab);
-    for (int i=0; i<n_vocab; ++i) {
-        std::string tok = std::string(strings[i]);
-        float score = scores[i];
-        vocab->id_to_token[i].tok = tok;
-        vocab->id_to_token[i].score = score;
-        vocab->token_to_id.emplace(tok, i);
-    }
-    llama_free(lctx);
-    llama_free_model(lmodel);
+bool is_ggml_file(const char *filename) {
+    llama_file file(filename, "rb");
+    if (file.size < 4) {
+        return false;
+    }
+    uint32_t magic = file.read_u32();
+    return magic == LLAMA_FILE_MAGIC;
+}
+
+void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) {
+    // heuristic to infer whether vocab is from ggml or from llama2.c vocabulary
+    if (is_ggml_file(filename)) {
+
+        struct llama_context_params llama_params = llama_context_default_params();
+        llama_params.vocab_only = true;
+
+        struct llama_model * lmodel = llama_load_model_from_file(filename, llama_params);
+        struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params);
+
+        std::vector<const char *> strings;
+        std::vector<float> scores;
+        int n_vocab = llama_n_vocab(lctx);
+        strings.resize(n_vocab, NULL);
+        scores.resize(n_vocab, 0);
+        n_vocab = llama_get_vocab(lctx, strings.data(), scores.data(), n_vocab);
+        GGML_ASSERT(n_vocab == llama_n_vocab(lctx));
+        vocab->id_to_token.resize(n_vocab);
+        for (int i=0; i<n_vocab; ++i) {
+            std::string tok = std::string(strings[i]);
+            float score = scores[i];
+            vocab->id_to_token[i].tok = tok;
+            vocab->id_to_token[i].score = score;
+            vocab->token_to_id.emplace(tok, i);
+        }
+        llama_free(lctx);
+        llama_free_model(lmodel);
+    } else { // assume llama2.c vocabulary
+        printf("Assuming llama2.c vocabulary since %s is not a ggml file\n", filename);
+        llama_file file(filename, "rb");
+        uint32_t n_vocab = config->vocab_size;
+        /* uint32_t max_token_length = */ file.read_u32(); // unused
+        vocab->id_to_token.resize(n_vocab);
+        for (uint32_t i=0; i<n_vocab; ++i) {
+            float_t score = file.read_f32();
+            uint32_t len = file.read_u32();
+            std::string tok = file.read_string(len);
+            vocab->id_to_token[i].tok = tok;
+            vocab->id_to_token[i].score = score;
+            vocab->token_to_id.emplace(tok, i);
+        }
+    }
 }
 
 void stuff_karpathy_weights_into_gg(struct ggml_tensor * gg_weights, float * karpathy_weights){
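The else branch implies the llama2.c vocabulary layout: a single uint32 header (max_token_length), followed by config->vocab_size records of [float32 score][uint32 length][length bytes of token text]. The file stores no token count, which is why load_vocab() now takes the Config. A self-contained sketch of a reader for that layout, assuming a little-endian host; read_llama2c_vocab and TokenEntry are illustrative names, not part of this codebase:

```cpp
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

struct TokenEntry {
    std::string tok;
    float       score;
};

// Sketch of a standalone reader for the llama2.c vocabulary format implied by
// load_vocab() above: [u32 max_token_length][n_vocab x (f32 score, u32 len,
// len bytes)]. n_vocab must be supplied from the model's Config, since the
// vocabulary file itself carries no count. Illustrative, not llama.cpp API.
static std::vector<TokenEntry> read_llama2c_vocab(const char * fname, uint32_t n_vocab) {
    std::vector<TokenEntry> entries;
    std::FILE * fp = std::fopen(fname, "rb");
    if (!fp) {
        return entries;
    }
    uint32_t max_token_length = 0;
    std::fread(&max_token_length, sizeof(max_token_length), 1, fp); // header, unused here
    entries.resize(n_vocab);
    for (uint32_t i = 0; i < n_vocab; ++i) {
        uint32_t len = 0;
        std::fread(&entries[i].score, sizeof(float), 1, fp);
        std::fread(&len, sizeof(len), 1, fp);
        std::vector<char> buf(len);
        if (len > 0) {
            std::fread(buf.data(), 1, len, fp);
        }
        entries[i].tok = std::string(buf.begin(), buf.end());
    }
    std::fclose(fp);
    return entries;
}
```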
@@ -684,7 +716,7 @@ void print_usage(int /*argc*/, char ** argv, const struct train_params * params)
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help                       show this help message and exit\n");
-    fprintf(stderr, "  --copy-vocab-from-model FNAME    model path from which to copy vocab (default '%s')\n", params->fn_vocab_model);
+    fprintf(stderr, "  --copy-vocab-from-model FNAME    llama2.c vocabulary or ggml model path from which to copy vocab (default '%s')\n", params->fn_vocab_model);
     fprintf(stderr, "  --llama2c-model FNAME            [REQUIRED] model path from which to load Karpathy's llama2.c model\n");
     fprintf(stderr, "  --llama2c-output-model FNAME     model path to save the converted llama2.c model (default %s')\n", params->fn_llama2c_output_model);
     fprintf(stderr, "\n");
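With this change, --copy-vocab-from-model accepts either a ggml model or a llama2.c vocabulary file. A hypothetical invocation (binary and file names are placeholders, not from this commit):

    ./convert-llama2c-to-ggml \
      --copy-vocab-from-model tokenizer.bin \
      --llama2c-model stories42M.bin \
      --llama2c-output-model stories42M.ggml.bin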
@@ -764,7 +796,7 @@ int main(int argc, char ** argv) {
     }
 
     struct llama_vocab vocab;
-    load_vocab(params.fn_vocab_model, &vocab);
+    load_vocab(params.fn_vocab_model, &config, &vocab);
 
     struct my_llama_model model;
     model.hparams.n_vocab = config.vocab_size; //llama_n_vocab(lctx);
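load_vocab() needs the extra Config* because, in the llama2.c case, the token count comes from the model header rather than the vocabulary file. For orientation, a sketch of the llama2.c Config this converter reads; the field layout follows karpathy/llama2.c and should be treated as illustrative rather than copied from this repository:

```cpp
// Sketch of llama2.c's model header (Config). load_vocab() only consumes
// vocab_size; the other fields drive the weight conversion. Layout per
// karpathy/llama2.c -- illustrative, not authoritative for this repo.
struct Config {
    int dim;        // transformer embedding dimension
    int hidden_dim; // FFN hidden dimension
    int n_layers;   // number of transformer layers
    int n_heads;    // number of attention (query) heads
    int n_kv_heads; // number of key/value heads
    int vocab_size; // token count -- what load_vocab() relies on
    int seq_len;    // maximum sequence length
};
```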