llama : vocab cleanup

ggml-ci
Georgi Gerganov 2025-01-08 22:17:15 +02:00
parent 9dd71e078f
commit cee3648ee3
4 changed files with 26 additions and 21 deletions
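At its core, the cleanup turns llama_vocab's public n_vocab counter into an n_vocab() accessor, so call sites change from vocab.n_vocab to vocab.n_vocab() and the raw field moves out of the public header. A minimal sketch of that field-to-accessor pattern (hypothetical type names, not the actual llama.cpp declarations):

#include <cstdint>
#include <string>
#include <vector>

// before: callers read a public field that had to be kept in sync by hand
struct vocab_v1 {
    uint32_t n_vocab = 0;
    std::vector<std::string> id_to_token;
};

// after: callers go through a const accessor; the count is derived from the token table
struct vocab_v2 {
public:
    uint32_t n_vocab() const { return (uint32_t) id_to_token.size(); }
private:
    std::vector<std::string> id_to_token;
};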

View file

@@ -1663,7 +1663,7 @@ struct llama_sampler_dry {
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
- for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
+ for (llama_token token_id = 0; token_id < (llama_token) vocab.n_vocab(); token_id++) {
std::string word = vocab.detokenize({token_id}, true);
if (word.find(str) != std::string::npos) {
token_sequences.emplace(token_id, std::vector<llama_token>());

View file

@@ -208,7 +208,7 @@ private:
return;
}
- if (static_cast<uint32_t>(token) >= vocab.n_vocab) {
+ if (static_cast<uint32_t>(token) >= vocab.n_vocab()) {
return;
}
@@ -734,7 +734,7 @@ struct llm_tokenizer_ugm : llm_tokenizer {
prefix_replacements_size = precompiled_charsmap.size() - charsmap_offset;
}
- for (uint32_t id = 0; id < vocab.n_vocab; ++id) {
+ for (uint32_t id = 0; id < vocab.n_vocab(); ++id) {
const auto & token_data = vocab.get_token_data(id);
if (vocab.is_normal(id)) {
@@ -1119,7 +1119,7 @@ struct llm_tokenizer_rwkv : llm_tokenizer {
// For now, we decode the vocab here into the lookup we'll use for tokenization.
// build trie
- for (uint32_t id = 0; id < vocab.n_vocab; ++id) {
+ for (uint32_t id = 0; id < vocab.n_vocab(); ++id) {
const auto & data = vocab.get_token_data(id);
const auto text = llama_unescape_rwkv_token(data.text);
token_matcher.insert((const char *) text.data(), text.size(), id);
@@ -1204,6 +1204,8 @@ struct fragment_buffer_variant {
};
struct llama_vocab::impl {
+ uint32_t n_vocab = 0;
std::unordered_map<std::string, llama_token> token_to_id;
std::vector<token_data> id_to_token;
@@ -1283,6 +1285,13 @@ llama_vocab::~llama_vocab() {
void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
struct gguf_context * ctx = ml.meta.get();
+ auto & n_vocab = pimpl->n_vocab;
+ auto & id_to_token = pimpl->id_to_token;
+ auto & token_to_id = pimpl->token_to_id;
+ auto & special_eog_ids = pimpl->special_eog_ids;
+ auto & cache_special_tokens = pimpl->cache_special_tokens;
+ auto & cache_token_to_piece = pimpl->cache_token_to_piece;
// determine vocab type
{
std::string tokenizer_model;
@@ -1589,12 +1598,6 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
}
- auto & id_to_token = pimpl->id_to_token;
- auto & token_to_id = pimpl->token_to_id;
- auto & special_eog_ids = pimpl->special_eog_ids;
- auto & cache_special_tokens = pimpl->cache_special_tokens;
- auto & cache_token_to_piece = pimpl->cache_token_to_piece;
n_vocab = gguf_get_arr_n(ctx, token_idx);
id_to_token.resize(n_vocab);
@@ -1908,7 +1911,7 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
// build special tokens cache
{
- for (llama_token id = 0; id < (llama_token)n_vocab; ++id) {
+ for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
if (id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
cache_special_tokens.push_back(id);
}
@@ -2002,6 +2005,10 @@ enum llama_vocab_pre_type llama_vocab::get_pre_type() const {
return pre_type;
}
+ uint32_t llama_vocab::n_vocab() const {
+ return (uint32_t) pimpl->id_to_token.size();
+ }
std::string llama_vocab::type_name() const{
switch (type) {
case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
@@ -2366,8 +2373,8 @@ int llama_vocab::max_token_text_len() const {
void llama_vocab::print_info() const {
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str());
- LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, n_vocab);
- LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) pimpl->bpe_ranks.size());
+ LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, pimpl->n_vocab);
+ LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) pimpl->bpe_ranks.size());
auto & id_to_token = pimpl->id_to_token;
auto & special_eog_ids = pimpl->special_eog_ids;
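The accessor introduced above reports the size of pimpl->id_to_token, while the raw counter itself moves out of the public header and into llama_vocab::impl. A small stand-alone sketch of that pimpl-backed accessor idea (hypothetical class, assuming the token table is the single source of truth):

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

class vocab_sketch {
public:
    vocab_sketch() : pimpl(new impl()) {}
    void add_token(const std::string & text) { pimpl->id_to_token.push_back(text); }
    // derived on demand: no second counter that could drift out of sync
    uint32_t n_vocab() const { return (uint32_t) pimpl->id_to_token.size(); }
private:
    struct impl {
        std::vector<std::string> id_to_token;
    };
    std::unique_ptr<impl> pimpl;
};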

View file

@@ -4,9 +4,6 @@
#include <string>
#include <vector>
#include <unordered_map>
- #include <map>
- #include <set>
#include <memory>
struct LLM_KV;
@@ -19,8 +16,6 @@ struct llama_vocab {
llama_token_attr attr;
};
- uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
llama_vocab();
~llama_vocab();
@@ -29,6 +24,9 @@ struct llama_vocab {
enum llama_vocab_type get_type() const;
enum llama_vocab_pre_type get_pre_type() const;
+ // TODO: how to deduplicate with llama_hparams.n_vocab ?
+ uint32_t n_vocab() const;
std::string type_name() const;
bool is_normal (llama_token id) const;

View file

@@ -66,7 +66,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
model.print_info();
if (model.vocab.get_type() != LLAMA_VOCAB_TYPE_NONE &&
- model.hparams.n_vocab != model.vocab.n_vocab) {
+ model.hparams.n_vocab != model.vocab.n_vocab()) {
throw std::runtime_error("vocab size mismatch");
}
@@ -8474,7 +8474,7 @@ static int llama_decode_impl(
if (batch.token) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
- if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_vocab()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
@@ -8809,7 +8809,7 @@ static int llama_encode_impl(
if (batch.token) {
for (uint32_t i = 0; i < n_tokens; ++i) {
- if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+ if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_vocab()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
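The decode and encode hunks keep the same bounds check as before, just routed through the accessor: a token id is accepted only if it is non-negative and below n_vocab(). A self-contained sketch of that validation loop (hypothetical helper, not the actual llama_decode_impl code):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// returns false on the first id that falls outside [0, n_vocab)
static bool tokens_in_range(const std::vector<int32_t> & tokens, uint32_t n_vocab) {
    for (size_t i = 0; i < tokens.size(); ++i) {
        if (tokens[i] < 0 || (uint32_t) tokens[i] >= n_vocab) {
            std::fprintf(stderr, "invalid token[%zu] = %d\n", i, tokens[i]);
            return false;
        }
    }
    return true;
}

int main() {
    const std::vector<int32_t> tokens = { 1, 5, 42 };
    return tokens_in_range(tokens, /*n_vocab=*/ 32000) ? 0 : 1;
}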