llama : vocab cleanup
ggml-ci
parent f784700c31
commit ad1923a0ce
4 changed files with 26 additions and 21 deletions
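In short, the cleanup moves the vocab size out of the public llama_vocab struct and behind its pimpl, exposing it through a new n_vocab() accessor derived from the token table, so callers no longer read a field that has to be kept in sync by hand. The sketch below is not part of the commit; it only illustrates that before/after shape with simplified, assumed types.

    #include <cstdint>
    #include <memory>
    #include <string>
    #include <vector>

    // Before (simplified): a public counter that had to be kept in sync manually.
    struct vocab_before {
        uint32_t n_vocab = 0;                 // easy to let drift from the table below
        std::vector<std::string> id_to_token; // the actual source of truth
    };

    // After (simplified): the count lives behind the pimpl and is derived on demand.
    struct vocab_after {
        struct impl {
            std::vector<std::string> id_to_token;
        };
        std::unique_ptr<impl> pimpl = std::make_unique<impl>();

        uint32_t n_vocab() const {
            return (uint32_t) pimpl->id_to_token.size(); // always consistent with the table
        }
    };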
@@ -1663,7 +1663,7 @@ struct llama_sampler_dry {
 
 // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
-    for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
+    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_vocab(); token_id++) {
         std::string word = vocab.detokenize({token_id}, true);
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());
@@ -208,7 +208,7 @@ private:
             return;
         }
 
-        if (static_cast<uint32_t>(token) >= vocab.n_vocab) {
+        if (static_cast<uint32_t>(token) >= vocab.n_vocab()) {
             return;
         }
 
@@ -734,7 +734,7 @@ struct llm_tokenizer_ugm : llm_tokenizer {
             prefix_replacements_size = precompiled_charsmap.size() - charsmap_offset;
         }
 
-        for (uint32_t id = 0; id < vocab.n_vocab; ++id) {
+        for (uint32_t id = 0; id < vocab.n_vocab(); ++id) {
             const auto & token_data = vocab.get_token_data(id);
 
             if (vocab.is_normal(id)) {
@@ -1119,7 +1119,7 @@ struct llm_tokenizer_rwkv : llm_tokenizer {
         // For now, we decode the vocab here into the lookup we'll use for tokenization.
 
         // build trie
-        for (uint32_t id = 0; id < vocab.n_vocab; ++id) {
+        for (uint32_t id = 0; id < vocab.n_vocab(); ++id) {
             const auto & data = vocab.get_token_data(id);
             const auto text = llama_unescape_rwkv_token(data.text);
             token_matcher.insert((const char *) text.data(), text.size(), id);
@@ -1204,6 +1204,8 @@ struct fragment_buffer_variant {
 };
 
 struct llama_vocab::impl {
+    uint32_t n_vocab = 0;
+
     std::unordered_map<std::string, llama_token> token_to_id;
     std::vector<token_data> id_to_token;
 
@@ -1283,6 +1285,13 @@ llama_vocab::~llama_vocab() {
 void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
     struct gguf_context * ctx = ml.meta.get();
 
+    auto & n_vocab = pimpl->n_vocab;
+    auto & id_to_token = pimpl->id_to_token;
+    auto & token_to_id = pimpl->token_to_id;
+    auto & special_eog_ids = pimpl->special_eog_ids;
+    auto & cache_special_tokens = pimpl->cache_special_tokens;
+    auto & cache_token_to_piece = pimpl->cache_token_to_piece;
+
     // determine vocab type
     {
         std::string tokenizer_model;
@@ -1589,12 +1598,6 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
         toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
     }
 
-    auto & id_to_token = pimpl->id_to_token;
-    auto & token_to_id = pimpl->token_to_id;
-    auto & special_eog_ids = pimpl->special_eog_ids;
-    auto & cache_special_tokens = pimpl->cache_special_tokens;
-    auto & cache_token_to_piece = pimpl->cache_token_to_piece;
-
     n_vocab = gguf_get_arr_n(ctx, token_idx);
     id_to_token.resize(n_vocab);
 
@@ -1908,7 +1911,7 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
 
     // build special tokens cache
     {
-        for (llama_token id = 0; id < (llama_token)n_vocab; ++id) {
+        for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
             if (id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
                 cache_special_tokens.push_back(id);
             }
@@ -2002,6 +2005,10 @@ enum llama_vocab_pre_type llama_vocab::get_pre_type() const {
     return pre_type;
 }
 
+uint32_t llama_vocab::n_vocab() const {
+    return (uint32_t) pimpl->id_to_token.size();
+}
+
 std::string llama_vocab::type_name() const{
     switch (type) {
         case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
@@ -2366,8 +2373,8 @@ int llama_vocab::max_token_text_len() const {
 
 void llama_vocab::print_info() const {
     LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str());
-    LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, n_vocab);
-    LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) pimpl->bpe_ranks.size());
+    LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, pimpl->n_vocab);
+    LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) pimpl->bpe_ranks.size());
 
     auto & id_to_token = pimpl->id_to_token;
     auto & special_eog_ids = pimpl->special_eog_ids;
@@ -4,9 +4,6 @@
 
 #include <string>
 #include <vector>
-#include <unordered_map>
-#include <map>
-#include <set>
 #include <memory>
 
 struct LLM_KV;
@@ -19,8 +16,6 @@ struct llama_vocab {
         llama_token_attr attr;
     };
 
-    uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
-
     llama_vocab();
     ~llama_vocab();
 
@@ -29,6 +24,9 @@ struct llama_vocab {
     enum llama_vocab_type get_type() const;
     enum llama_vocab_pre_type get_pre_type() const;
 
+    // TODO: how to deduplicate with llama_hparams.n_vocab ?
+    uint32_t n_vocab() const;
+
     std::string type_name() const;
 
     bool is_normal (llama_token id) const;
@@ -66,7 +66,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
         model.print_info();
 
         if (model.vocab.get_type() != LLAMA_VOCAB_TYPE_NONE &&
-            model.hparams.n_vocab != model.vocab.n_vocab) {
+            model.hparams.n_vocab != model.vocab.n_vocab()) {
             throw std::runtime_error("vocab size mismatch");
         }
 
@@ -8349,7 +8349,7 @@ static int llama_decode_impl(
 
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_vocab()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
@@ -8684,7 +8684,7 @@ static int llama_encode_impl(
 
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_vocab()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
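Usage-wise (not part of the diff), call sites now query the vocab size through the accessor; the decode/encode hunks above are the canonical example. A minimal sketch of the same bounds check, with a hypothetical helper name, assuming the llama.cpp types llama_token and llama_vocab are in scope:

    // Hypothetical helper mirroring the check in llama_decode_impl/llama_encode_impl.
    static bool tokens_in_range(const llama_vocab & vocab, const llama_token * tokens, uint32_t n) {
        for (uint32_t i = 0; i < n; ++i) {
            if (tokens[i] < 0 || (uint32_t) tokens[i] >= vocab.n_vocab()) {
                return false; // token id outside [0, n_vocab)
            }
        }
        return true;
    }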