models without vocabulary, llama.cpp part

This commit is contained in:
Michael Podvitskiy 2024-02-28 10:49:26 +01:00
parent 4f4258fbde
commit afa9d0953b
2 changed files with 63 additions and 37 deletions

View file

@ -3037,10 +3037,11 @@ static const char * llama_model_type_name(e_model type) {
static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
switch (type) { switch (type) {
case LLAMA_VOCAB_TYPE_SPM: return "SPM"; case LLAMA_VOCAB_TYPE_SPM: return "SPM";
case LLAMA_VOCAB_TYPE_BPE: return "BPE"; case LLAMA_VOCAB_TYPE_BPE: return "BPE";
case LLAMA_VOCAB_TYPE_WPM: return "WPM"; case LLAMA_VOCAB_TYPE_WPM: return "WPM";
default: return "unknown"; case LLAMA_VOCAB_TYPE_NO_VOCAB: return "no vocab";
default: return "unknown";
} }
} }
@ -3071,15 +3072,14 @@ static void llm_load_hparams(
// get general kv // get general kv
ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
// get hparams kv ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
ml.get_key (LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
ml.get_key (LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); ml.get_key(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff);
ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head); ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false);
ml.get_key (LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
ml.get_key (LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS); GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS);
GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert);
@ -3410,30 +3410,25 @@ static void llm_load_vocab(
const auto kv = LLM_KV(model.arch); const auto kv = LLM_KV(model.arch);
const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
if (token_idx == -1) {
throw std::runtime_error("cannot find tokenizer vocab in model file\n");
}
const float * scores = nullptr;
const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
if (score_idx != -1) {
scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
}
const int * toktypes = nullptr;
const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
if (toktype_idx != -1) {
toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
}
// determine vocab type // determine vocab type
{ {
std::string tokenizer_name; std::string tokenizer_name;
ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name); ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name);
if (tokenizer_name == "llama") { if (tokenizer_name == "no_vocab") {
vocab.type = LLAMA_VOCAB_TYPE_NO_VOCAB;
// default special tokens
vocab.special_bos_id = -1;
vocab.special_eos_id = -1;
vocab.special_unk_id = -1;
vocab.special_sep_id = -1;
vocab.special_pad_id = -1;
vocab.linefeed_id = -1;
return;
} else if (tokenizer_name == "llama") {
vocab.type = LLAMA_VOCAB_TYPE_SPM; vocab.type = LLAMA_VOCAB_TYPE_SPM;
// default special tokens // default special tokens
@ -3499,6 +3494,23 @@ static void llm_load_vocab(
} }
} }
const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
if (token_idx == -1) {
throw std::runtime_error("cannot find tokenizer vocab in model file\n");
}
const float * scores = nullptr;
const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
if (score_idx != -1) {
scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
}
const int * toktypes = nullptr;
const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
if (toktype_idx != -1) {
toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
}
const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
vocab.id_to_token.resize(n_vocab); vocab.id_to_token.resize(n_vocab);
@ -4725,7 +4737,8 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
llm_load_print_meta(ml, model); llm_load_print_meta(ml, model);
if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { if (model.vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB &&
model.hparams.n_vocab != model.vocab.id_to_token.size()) {
throw std::runtime_error("vocab size mismatch"); throw std::runtime_error("vocab size mismatch");
} }
@ -8714,26 +8727,32 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
} }
static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB);
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL; return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL;
} }
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB);
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN; return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN;
} }
static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB);
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL; return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL;
} }
static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB);
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE; return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE;
} }
static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) { static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB);
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED; return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
} }
static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NO_VOCAB);
GGML_ASSERT(llama_is_byte_token(vocab, id)); GGML_ASSERT(llama_is_byte_token(vocab, id));
const auto& token_data = vocab.id_to_token.at(id); const auto& token_data = vocab.id_to_token.at(id);
switch (llama_vocab_get_type(vocab)) { switch (llama_vocab_get_type(vocab)) {
@ -8754,6 +8773,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
} }
static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NO_VOCAB);
static const char * hex = "0123456789ABCDEF"; static const char * hex = "0123456789ABCDEF";
switch (llama_vocab_get_type(vocab)) { switch (llama_vocab_get_type(vocab)) {
case LLAMA_VOCAB_TYPE_SPM: { case LLAMA_VOCAB_TYPE_SPM: {
@ -9598,6 +9618,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
} }
} }
} break; } break;
case LLAMA_VOCAB_TYPE_NO_VOCAB:
GGML_ASSERT(false);
} }
return output; return output;
@ -13164,8 +13186,8 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_
} }
void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch) { void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch) {
assert(n_threads); GGML_ASSERT(n_threads);
assert(n_threads_batch); GGML_ASSERT(n_threads_batch);
*n_threads = ctx->cparams.n_threads; *n_threads = ctx->cparams.n_threads;
*n_threads_batch = ctx->cparams.n_threads_batch; *n_threads_batch = ctx->cparams.n_threads_batch;
} }
@ -13268,14 +13290,17 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id
} }
const char * llama_token_get_text(const struct llama_model * model, llama_token token) { const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB);
return model->vocab.id_to_token[token].text.c_str(); return model->vocab.id_to_token[token].text.c_str();
} }
float llama_token_get_score(const struct llama_model * model, llama_token token) { float llama_token_get_score(const struct llama_model * model, llama_token token) {
GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB);
return model->vocab.id_to_token[token].score; return model->vocab.id_to_token[token].score;
} }
llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) { llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) {
GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB);
return model->vocab.id_to_token[token].type; return model->vocab.id_to_token[token].type;
} }

View file

@ -59,9 +59,10 @@ extern "C" {
typedef int32_t llama_seq_id; typedef int32_t llama_seq_id;
enum llama_vocab_type { enum llama_vocab_type {
LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
LLAMA_VOCAB_TYPE_NO_VOCAB = 3, // For models without vocab
}; };
// note: these values should be synchronized with ggml_rope // note: these values should be synchronized with ggml_rope