llama2.c: support copying vocab from a llama gguf model file
This commit is contained in:
parent 63174b8073
commit c88362d194
1 changed file with 82 additions and 24 deletions
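
In short: load_vocab() now checks is_ggml_file() and, when the vocab source is a gguf llama model, copies the tokenizer (tokens, scores, token types) directly from the gguf metadata; otherwise it falls back to the llama2.c vocabulary format as before, now assigning the special tokens <unk>, <s>, </s> explicitly. A hypothetical invocation of the converter (the flag and file names here are illustrative, not taken from this commit):

    ./convert-llama2c-to-ggml \
        --copy-vocab-from-model llama-2-7b.gguf \
        --llama2c-model stories42M.bin \
        --llama2c-output-model stories42M.gguf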
@@ -531,31 +531,87 @@ bool is_ggml_file(const char *filename) {
     return magic == GGUF_MAGIC;
 }
 
-llama_vocab::ttype get_token_type(llama_vocab::id id, const llama_vocab::token &text) {
-    if (id == UNKNOWN_TOKEN_ID) return LLAMA_TOKEN_TYPE_UNKNOWN;
-    if (id == BOS_TOKEN_ID || id == EOS_TOKEN_ID) return LLAMA_TOKEN_TYPE_CONTROL;
-    unsigned char byte_val;
-    if (sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) {
-        return LLAMA_TOKEN_TYPE_BYTE;
-    }
-    return LLAMA_TOKEN_TYPE_NORMAL;
-}
-
 void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) {
-    // assume llama2.c vocabulary
-    printf("Assuming llama2.c vocabulary since %s is not a ggml file\n", filename);
-    llama_file file(filename, "rb");
-    const int n_vocab = config->vocab_size;
-    /* uint32_t max_token_length = */ file.read_u32(); // unused
-    vocab->id_to_token.resize(n_vocab);
-    for (int i=0; i<n_vocab; ++i) {
-        float_t score = file.read_f32();
-        uint32_t len = file.read_u32();
-        std::string text = file.read_string(len);
-        vocab->id_to_token[i].text = text;
-        vocab->id_to_token[i].score = score;
-        vocab->id_to_token[i].type = get_token_type(i, text);
-        vocab->token_to_id.emplace(text, i);
+    if (is_ggml_file(filename)) {
+        struct ggml_context * ctx_data = NULL;
+
+        struct gguf_init_params params = {
+            /*.no_alloc = */ false,
+            /*.ctx      = */ &ctx_data,
+        };
+
+        struct gguf_context * ctx = gguf_init_from_file(filename, params);
+        GGML_ASSERT(ctx != NULL);
+
+        const int model_idx = gguf_find_key(ctx, "tokenizer.ggml.model");
+        GGML_ASSERT(model_idx >= 0);
+        std::string tokenizer_name = gguf_get_val_str(ctx, model_idx);
+        GGML_ASSERT(tokenizer_name == "llama");
+
+        const int token_idx = gguf_find_key(ctx, "tokenizer.ggml.tokens");
+        GGML_ASSERT(token_idx >= 0);
+
+        const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores");
+        GGML_ASSERT(score_idx >= 0);
+        const float * scores = (const float *) gguf_get_arr_data(ctx, score_idx);
+
+        const int toktype_idx = gguf_find_key(ctx, "tokenizer.ggml.token_type");
+        GGML_ASSERT(toktype_idx >= 0);
+        const int * toktypes = (const int *) gguf_get_arr_data(ctx, toktype_idx);
+
+        const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
+
+        vocab->id_to_token.resize(n_vocab);
+
+        for (uint32_t i = 0; i < n_vocab; i++) {
+            std::string word = gguf_get_arr_str(ctx, token_idx, i);
+
+            vocab->token_to_id[word] = i;
+
+            auto & token_data = vocab->id_to_token[i];
+            token_data.text  = std::move(word);
+            token_data.score = scores[i];
+            token_data.type  = (llama_token_type) toktypes[i];
+        }
+        ggml_free(ctx_data);
+        gguf_free(ctx);
+    } else {
+        // assume llama2.c vocabulary
+        printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename);
+        llama_file file(filename, "rb");
+        const int n_vocab = config->vocab_size;
+        /* uint32_t max_token_length = */ file.read_u32(); // unused
+        vocab->id_to_token.resize(n_vocab);
+        for (llama_vocab::id id = 0; id < n_vocab; ++id) {
+            float_t score = file.read_f32();
+            uint32_t len = file.read_u32();
+            std::string text = file.read_string(len);
+
+            unsigned char byte_val;
+            llama_vocab::ttype type = LLAMA_TOKEN_TYPE_NORMAL;
+            if (id == UNKNOWN_TOKEN_ID) {
+                text = "<unk>";
+                type = LLAMA_TOKEN_TYPE_UNKNOWN;
+            } else if (id == BOS_TOKEN_ID) {
+                text = "<s>";
+                type = LLAMA_TOKEN_TYPE_CONTROL;
+            } else if (id == EOS_TOKEN_ID) {
+                text = "</s>";
+                type = LLAMA_TOKEN_TYPE_CONTROL;
+            } else if (text.empty()) {
+                type = LLAMA_TOKEN_TYPE_CONTROL;
+            } else if (sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) {
+                // Text of byte tokens is already in the expected format.
+                type = LLAMA_TOKEN_TYPE_BYTE;
+            } else {
+                type = LLAMA_TOKEN_TYPE_NORMAL;
+            }
+
+            vocab->id_to_token[id].text  = text;
+            vocab->id_to_token[id].score = score;
+            vocab->id_to_token[id].type  = type;
+            vocab->token_to_id.emplace(text, id);
+        }
     }
 }
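
For reference, a minimal standalone sketch of the same gguf read pattern, using only the calls that appear in the diff; the input path vocab.gguf is a placeholder, and the gguf API is assumed to be declared in ggml.h as it was at the time of this commit:

    // dump the first tokens of a gguf llama tokenizer (sketch)
    #include "ggml.h"

    #include <cstdint>
    #include <cstdio>
    #include <string>

    int main() {
        struct ggml_context * ctx_data = NULL;
        struct gguf_init_params params = {
            /*.no_alloc = */ false,
            /*.ctx      = */ &ctx_data,
        };

        // "vocab.gguf" is a placeholder path
        struct gguf_context * ctx = gguf_init_from_file("vocab.gguf", params);
        if (ctx == NULL) {
            fprintf(stderr, "not a gguf file\n");
            return 1;
        }

        const int token_idx = gguf_find_key(ctx, "tokenizer.ggml.tokens");
        const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores");
        if (token_idx < 0 || score_idx < 0) {
            fprintf(stderr, "tokenizer arrays missing\n");
            ggml_free(ctx_data);
            gguf_free(ctx);
            return 1;
        }

        const float *  scores  = (const float *) gguf_get_arr_data(ctx, score_idx);
        const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);

        // print the first ten vocabulary entries with their scores
        for (uint32_t i = 0; i < n_vocab && i < 10; i++) {
            std::string word = gguf_get_arr_str(ctx, token_idx, i);
            printf("%u: '%s' (score %.3f)\n", i, word.c_str(), scores[i]);
        }

        ggml_free(ctx_data);
        gguf_free(ctx);
        return 0;
    }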
@@ -651,6 +707,8 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
     gguf_set_val_u32(ctx, "tokenizer.ggml.unknown_token_id", UNKNOWN_TOKEN_ID);
     gguf_set_val_u32(ctx, "tokenizer.ggml.bos_token_id", BOS_TOKEN_ID);
     gguf_set_val_u32(ctx, "tokenizer.ggml.eos_token_id", EOS_TOKEN_ID);
+    gguf_set_val_u32(ctx, "tokenizer.ggml.sep_token_id", -1);
+    gguf_set_val_u32(ctx, "tokenizer.ggml.pad_token_id", -1);
 
     gguf_set_val_u32(ctx, "llama.context_length", model->hparams.n_ctx);
     gguf_set_val_u32(ctx, "llama.embedding_length", model->hparams.n_embd);