llama2.c: support copying vocab from a llama gguf model file

Author: ochafik
Date:   2023-08-26 19:47:11 +01:00
parent 63174b8073
commit c88362d194


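For context, the converter is typically invoked like this once the change is in place (flag names per the example's README; the model paths below are placeholders, not from this commit):

    ./convert-llama2c-to-ggml \
        --copy-vocab-from-model llama-2-7b.gguf \
        --llama2c-model stories42M.bin \
        --llama2c-output-model stories42M.gguf.bin

With this commit, --copy-vocab-from-model accepts a llama gguf model file in addition to a llama2.c vocabulary file: the llama2.c checkpoint supplies the weights and the gguf model supplies the tokenizer.
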
@@ -531,31 +531,87 @@ bool is_ggml_file(const char *filename) {
     return magic == GGUF_MAGIC;
 }
 
-llama_vocab::ttype get_token_type(llama_vocab::id id, const llama_vocab::token &text) {
-    if (id == UNKNOWN_TOKEN_ID) return LLAMA_TOKEN_TYPE_UNKNOWN;
-    if (id == BOS_TOKEN_ID || id == EOS_TOKEN_ID) return LLAMA_TOKEN_TYPE_CONTROL;
-    unsigned char byte_val;
-    if (sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) {
-        return LLAMA_TOKEN_TYPE_BYTE;
-    }
-    return LLAMA_TOKEN_TYPE_NORMAL;
-}
-
 void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) {
-    // assume llama2.c vocabulary
-    printf("Assuming llama2.c vocabulary since %s is not a ggml file\n", filename);
-    llama_file file(filename, "rb");
-    const int n_vocab = config->vocab_size;
-    /* uint32_t max_token_length = */ file.read_u32(); // unused
-    vocab->id_to_token.resize(n_vocab);
-    for (int i=0; i<n_vocab; ++i) {
-        float_t score = file.read_f32();
-        uint32_t len = file.read_u32();
-        std::string text = file.read_string(len);
-        vocab->id_to_token[i].text = text;
-        vocab->id_to_token[i].score = score;
-        vocab->id_to_token[i].type = get_token_type(i, text);
-        vocab->token_to_id.emplace(text, i);
+    if (is_ggml_file(filename)) {
+        struct ggml_context * ctx_data = NULL;
+
+        struct gguf_init_params params = {
+            /*.no_alloc = */ false,
+            /*.ctx      = */ &ctx_data,
+        };
+
+        struct gguf_context * ctx = gguf_init_from_file(filename, params);
+        GGML_ASSERT(ctx != NULL);
+
+        const int model_idx = gguf_find_key(ctx, "tokenizer.ggml.model");
+        GGML_ASSERT(model_idx >= 0);
+        std::string tokenizer_name = gguf_get_val_str(ctx, model_idx);
+        GGML_ASSERT(tokenizer_name == "llama");
+
+        const int token_idx = gguf_find_key(ctx, "tokenizer.ggml.tokens");
+        GGML_ASSERT(token_idx >= 0);
+
+        const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores");
+        GGML_ASSERT(score_idx >= 0);
+        const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
+
+        const int toktype_idx = gguf_find_key(ctx, "tokenizer.ggml.token_type");
+        GGML_ASSERT(toktype_idx >= 0);
+        const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
+
+        const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
+
+        vocab->id_to_token.resize(n_vocab);
+
+        for (uint32_t i = 0; i < n_vocab; i++) {
+            std::string word = gguf_get_arr_str(ctx, token_idx, i);
+
+            vocab->token_to_id[word] = i;
+
+            auto & token_data = vocab->id_to_token[i];
+            token_data.text  = std::move(word);
+            token_data.score = scores[i];
+            token_data.type  = (llama_token_type) toktypes[i];
+        }
+        ggml_free(ctx_data);
+        gguf_free(ctx);
+    } else {
+        // assume llama2.c vocabulary
+        printf("Assuming llama2.c vocabulary since %s is not a gguf file\n", filename);
+        llama_file file(filename, "rb");
+        const int n_vocab = config->vocab_size;
+        /* uint32_t max_token_length = */ file.read_u32(); // unused
+        vocab->id_to_token.resize(n_vocab);
+        for (llama_vocab::id id=0; id<n_vocab; ++id) {
+            float_t score = file.read_f32();
+            uint32_t len = file.read_u32();
+            std::string text = file.read_string(len);
+
+            unsigned char byte_val;
+            llama_vocab::ttype type = LLAMA_TOKEN_TYPE_NORMAL;
+            if (id == UNKNOWN_TOKEN_ID) {
+                text = "<unk>";
+                type = LLAMA_TOKEN_TYPE_UNKNOWN;
+            } else if (id == BOS_TOKEN_ID) {
+                text = "<s>";
+                type = LLAMA_TOKEN_TYPE_CONTROL;
+            } else if (id == EOS_TOKEN_ID) {
+                text = "</s>";
+                type = LLAMA_TOKEN_TYPE_CONTROL;
+            } else if (text.empty()) {
+                type = LLAMA_TOKEN_TYPE_CONTROL;
+            } else if (sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) {
+                // Text of byte tokens is already in the expected format.
+                type = LLAMA_TOKEN_TYPE_BYTE;
+            } else {
+                type = LLAMA_TOKEN_TYPE_NORMAL;
+            }
+            vocab->id_to_token[id].text = text;
+            vocab->id_to_token[id].score = score;
+            vocab->id_to_token[id].type = type;
+            vocab->token_to_id.emplace(text, id);
+        }
     }
 }
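
To make the new gguf branch easier to follow in isolation, here is a minimal sketch of the same read path as a standalone dump tool. It is an illustration, not part of the commit, and it assumes the gguf API is visible via ggml.h as elsewhere in this tree:

    #include "ggml.h"

    #include <cstdint>
    #include <cstdio>

    int main(int argc, char ** argv) {
        if (argc < 2) {
            fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
            return 1;
        }

        struct ggml_context * ctx_data = NULL;
        struct gguf_init_params params = {
            /*.no_alloc = */ false,
            /*.ctx      = */ &ctx_data,
        };

        // Parse the gguf header, key-value metadata and tensor data.
        struct gguf_context * ctx = gguf_init_from_file(argv[1], params);
        if (ctx == NULL) {
            fprintf(stderr, "failed to load %s\n", argv[1]);
            return 1;
        }

        // The llama tokenizer is stored as parallel arrays keyed by name.
        const int token_idx = gguf_find_key(ctx, "tokenizer.ggml.tokens");
        const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores");
        if (token_idx < 0 || score_idx < 0) {
            fprintf(stderr, "no llama tokenizer metadata in %s\n", argv[1]);
            ggml_free(ctx_data);
            gguf_free(ctx);
            return 1;
        }

        const float * scores = (const float *) gguf_get_arr_data(ctx, score_idx);
        const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);

        // Print the first few (id, score, text) entries.
        for (uint32_t i = 0; i < n_vocab && i < 8; i++) {
            printf("%5u %f %s\n", i, scores[i], gguf_get_arr_str(ctx, token_idx, i));
        }

        ggml_free(ctx_data);
        gguf_free(ctx);
        return 0;
    }
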
@@ -651,6 +707,8 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
     gguf_set_val_u32(ctx, "tokenizer.ggml.unknown_token_id", UNKNOWN_TOKEN_ID);
     gguf_set_val_u32(ctx, "tokenizer.ggml.bos_token_id", BOS_TOKEN_ID);
     gguf_set_val_u32(ctx, "tokenizer.ggml.eos_token_id", EOS_TOKEN_ID);
+    gguf_set_val_u32(ctx, "tokenizer.ggml.sep_token_id", -1);
+    gguf_set_val_u32(ctx, "tokenizer.ggml.pad_token_id", -1);
 
     gguf_set_val_u32(ctx, "llama.context_length", model->hparams.n_ctx);
     gguf_set_val_u32(ctx, "llama.embedding_length", model->hparams.n_embd);
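
A note on the second hunk: sep_token_id and pad_token_id are written through gguf_set_val_u32, so the -1 is stored as the u32 value 0xFFFFFFFF, the apparent convention here for "token not set". Below is a minimal sketch of writing the same keys to a metadata-only gguf file, assuming gguf_init_empty and gguf_write_to_file behave as elsewhere in this tree (the token ids and output path are placeholders):

    #include "ggml.h"

    int main() {
        struct gguf_context * ctx = gguf_init_empty();

        // Same keys as the hunk above; placeholder token ids.
        gguf_set_val_u32(ctx, "tokenizer.ggml.unknown_token_id", 0);
        gguf_set_val_u32(ctx, "tokenizer.ggml.bos_token_id",     1);
        gguf_set_val_u32(ctx, "tokenizer.ggml.eos_token_id",     2);
        gguf_set_val_u32(ctx, "tokenizer.ggml.sep_token_id",    -1); // wraps to 0xFFFFFFFF: not set
        gguf_set_val_u32(ctx, "tokenizer.ggml.pad_token_id",    -1); // wraps to 0xFFFFFFFF: not set

        // Write the key-value metadata only (no tensor data).
        gguf_write_to_file(ctx, "tokenizer-meta.gguf", /*only_meta =*/ true);

        gguf_free(ctx);
        return 0;
    }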