llama2.c: support copying vocab from a llama gguf model file
This commit is contained in:
parent 63174b8073
commit c88362d194

1 changed file with 82 additions and 24 deletions
@@ -531,31 +531,87 @@ bool is_ggml_file(const char *filename) {
     return magic == GGUF_MAGIC;
 }
 
-llama_vocab::ttype get_token_type(llama_vocab::id id, const llama_vocab::token &text) {
-    if (id == UNKNOWN_TOKEN_ID) return LLAMA_TOKEN_TYPE_UNKNOWN;
-    if (id == BOS_TOKEN_ID || id == EOS_TOKEN_ID) return LLAMA_TOKEN_TYPE_CONTROL;
-    unsigned char byte_val;
-    if (sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) {
-        return LLAMA_TOKEN_TYPE_BYTE;
-    }
-    return LLAMA_TOKEN_TYPE_NORMAL;
-}
-
 void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) {
+    if (is_ggml_file(filename)) {
+        struct ggml_context * ctx_data = NULL;
+
+        struct gguf_init_params params = {
+            /*.no_alloc = */ false,
+            /*.ctx      = */ &ctx_data,
+        };
+
+        struct gguf_context * ctx = gguf_init_from_file(filename, params);
+        GGML_ASSERT(ctx != NULL);
+
+        const int model_idx = gguf_find_key(ctx, "tokenizer.ggml.model");
+        GGML_ASSERT(model_idx >= 0);
+        std::string tokenizer_name = gguf_get_val_str(ctx, model_idx);
+        GGML_ASSERT(tokenizer_name == "llama");
+
+        const int token_idx = gguf_find_key(ctx, "tokenizer.ggml.tokens");
+        GGML_ASSERT(token_idx >= 0);
+
+        const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores");
+        GGML_ASSERT(score_idx >= 0);
+        const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
+
+        const int toktype_idx = gguf_find_key(ctx, "tokenizer.ggml.token_type");
+        GGML_ASSERT(toktype_idx >= 0);
+        const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
+
+        const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
+
+        vocab->id_to_token.resize(n_vocab);
+
+        for (uint32_t i = 0; i < n_vocab; i++) {
+            std::string word = gguf_get_arr_str(ctx, token_idx, i);
+
+            vocab->token_to_id[word] = i;
+
+            auto & token_data = vocab->id_to_token[i];
+            token_data.text  = std::move(word);
+            token_data.score = scores[i];
+            token_data.type  = (llama_token_type) toktypes[i];
+        }
+        ggml_free(ctx_data);
+        gguf_free(ctx);
+    } else {
         // assume llama2.c vocabulary
-        printf("Assuming llama2.c vocabulary since %s is not a ggml file\n", filename);
+        printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename);
         llama_file file(filename, "rb");
         const int n_vocab = config->vocab_size;
         /* uint32_t max_token_length = */ file.read_u32(); // unused
         vocab->id_to_token.resize(n_vocab);
-        for (int i=0; i<n_vocab; ++i) {
+        for (llama_vocab::id id=0; id<n_vocab; ++id) {
             float_t score = file.read_f32();
             uint32_t len = file.read_u32();
             std::string text = file.read_string(len);
-            vocab->id_to_token[i].text = text;
-            vocab->id_to_token[i].score = score;
-            vocab->id_to_token[i].type = get_token_type(i, text);
-            vocab->token_to_id.emplace(text, i);
+
+            unsigned char byte_val;
+            llama_vocab::ttype type = LLAMA_TOKEN_TYPE_NORMAL;
+            if (id == UNKNOWN_TOKEN_ID) {
+                text = "<unk>";
+                type = LLAMA_TOKEN_TYPE_UNKNOWN;
+            } else if (id == BOS_TOKEN_ID) {
+                text = "<s>";
+                type = LLAMA_TOKEN_TYPE_CONTROL;
+            } else if (id == EOS_TOKEN_ID) {
+                text = "</s>";
+                type = LLAMA_TOKEN_TYPE_CONTROL;
+            } else if (text.empty()) {
+                type = LLAMA_TOKEN_TYPE_CONTROL;
+            } else if (sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) {
+                // Text of byte tokens is already in the expected format.
+                type = LLAMA_TOKEN_TYPE_BYTE;
+            } else {
+                type = LLAMA_TOKEN_TYPE_NORMAL;
+            }
+
+            vocab->id_to_token[id].text = text;
+            vocab->id_to_token[id].score = score;
+            vocab->id_to_token[id].type = type;
+            vocab->token_to_id.emplace(text, id);
+        }
     }
 }
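A side note on the fallback branch above: SentencePiece-style vocabularies encode raw bytes as tokens whose text is literally "<0xNN>" (two uppercase hex digits), which is why a single sscanf against the pattern "<0x%02hhX>" is enough to classify byte tokens. A minimal standalone sketch, not part of the commit; the sample token list is made up for illustration, while the format string is copied verbatim from the diff:

// Standalone C++ sketch of the byte-token check used in load_vocab above.
// The sample tokens are hypothetical examples of vocabulary entries.
#include <cstdio>
#include <string>

int main() {
    const std::string samples[] = { "<0x0A>", "<0xFF>", "<s>", "hello" };
    for (const std::string & text : samples) {
        unsigned char byte_val;
        // Returns 1 only when the whole "<0xNN>" prefix matches and NN parses.
        if (std::sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) {
            std::printf("%-8s -> LLAMA_TOKEN_TYPE_BYTE (raw byte 0x%02X)\n", text.c_str(), byte_val);
        } else {
            std::printf("%-8s -> not a byte token\n", text.c_str());
        }
    }
    return 0;
}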
@@ -651,6 +707,8 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
     gguf_set_val_u32(ctx, "tokenizer.ggml.unknown_token_id", UNKNOWN_TOKEN_ID);
     gguf_set_val_u32(ctx, "tokenizer.ggml.bos_token_id", BOS_TOKEN_ID);
     gguf_set_val_u32(ctx, "tokenizer.ggml.eos_token_id", EOS_TOKEN_ID);
+    gguf_set_val_u32(ctx, "tokenizer.ggml.sep_token_id", -1);
+    gguf_set_val_u32(ctx, "tokenizer.ggml.pad_token_id", -1);
 
     gguf_set_val_u32(ctx, "llama.context_length", model->hparams.n_ctx);
     gguf_set_val_u32(ctx, "llama.embedding_length", model->hparams.n_embd);
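To sanity-check the extra keys written above, the same gguf API the converter already uses can read them back. A hedged sketch: the output path "out.gguf" and the standalone main are assumptions for illustration, and a -1 stored via gguf_set_val_u32 reads back as the unsigned value 0xFFFFFFFF:

// Sketch: dump the special-token ids from a converted model's metadata.
// Assumes the gguf API declared in ggml.h (the same calls the diff uses).
#include "ggml.h"
#include <cstdio>

int main() {
    struct gguf_init_params params = {
        /*.no_alloc = */ true,  // metadata only; skip tensor allocation
        /*.ctx      = */ NULL,
    };
    struct gguf_context * ctx = gguf_init_from_file("out.gguf", params);
    if (ctx == NULL) {
        std::fprintf(stderr, "failed to open gguf file\n");
        return 1;
    }
    const char * keys[] = {
        "tokenizer.ggml.unknown_token_id",
        "tokenizer.ggml.bos_token_id",
        "tokenizer.ggml.eos_token_id",
        "tokenizer.ggml.sep_token_id",
        "tokenizer.ggml.pad_token_id",
    };
    for (const char * key : keys) {
        const int idx = gguf_find_key(ctx, key);
        if (idx >= 0) {
            // sep/pad were written as -1, so they print as 4294967295 here
            std::printf("%s = %u\n", key, gguf_get_val_u32(ctx, idx));
        }
    }
    gguf_free(ctx);
    return 0;
}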