llama : advanced BPE tokenizer based on ggllm.cpp implementation
parent c3f8a6e49f
commit 3bfb720642
4 changed files with 327 additions and 58 deletions
@@ -43,7 +43,7 @@ static bool is_interacting = false;
void sigint_handler(int signo) {
    if (signo == SIGINT) {
        if (!is_interacting) {
            is_interacting=true;
            is_interacting = true;
        } else {
            console::cleanup();
            printf("\n");
@@ -189,10 +189,12 @@ int main(int argc, char ** argv) {
        }
    }

    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;

    // tokenize the prompt
    std::vector<llama_token> embd_inp;
    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
        embd_inp = ::llama_tokenize(ctx, params.prompt, true);
        embd_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
    } else {
        embd_inp = session_tokens;
    }
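The hunk above establishes the convention used throughout this commit: query the vocabulary type once and pass it as the add-BOS flag, since only SentencePiece (SPM) vocabularies expect a leading BOS token while GPT-2-style BPE vocabularies (e.g. Falcon) do not. A minimal sketch of the calling pattern, with a hypothetical helper name and assuming an already-loaded llama_context:

// sketch only: add BOS for SPM vocabs, skip it for BPE vocabs
static std::vector<llama_token> tokenize_prompt(llama_context * ctx, const std::string & prompt) {
    const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
    return ::llama_tokenize(ctx, prompt, add_bos);
}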
@@ -203,9 +205,9 @@ int main(int argc, char ** argv) {
    int original_prompt_len = 0;
    if (ctx_guidance) {
        params.cfg_negative_prompt.insert(0, 1, ' ');
        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, is_spm);

        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, is_spm);
        original_prompt_len = original_inp.size();
        guidance_offset = (int)guidance_inp.size() - original_prompt_len;
    }
@@ -252,7 +254,7 @@ int main(int argc, char ** argv) {
    }

    // prefix & suffix for instruct mode
    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", is_spm);
    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);

    // in instruct mode, we inject a prefix and a suffix to each input by the user
@@ -28,7 +28,6 @@ std::vector<float> softmax(const std::vector<float>& logits) {
}

void perplexity_v2(llama_context * ctx, const gpt_params & params) {

    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
    // Output: `perplexity: 13.5106 [114/114]`
@@ -38,7 +37,11 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
        return;
    }
    auto tokens = ::llama_tokenize(ctx, params.prompt, true);

    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
    const bool add_bos = is_spm;

    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);

    const int calc_chunk = params.n_ctx;
@@ -86,7 +89,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
            const auto token_org = tokens[batch_start];

            // add BOS token for the first batch of each chunk
            if (j == 0) {
            if (add_bos && j == 0) {
                tokens[batch_start] = llama_token_bos(ctx);
            }
@@ -136,7 +139,6 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
}

void perplexity(llama_context * ctx, const gpt_params & params) {

    if (params.ppl_stride > 0) {
        perplexity_v2(ctx, params);
        return;
@@ -146,7 +148,11 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
    // Output: `perplexity: 13.5106 [114/114]`
    // BOS tokens will be added for each chunk before eval
    auto tokens = ::llama_tokenize(ctx, params.prompt, true);

    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
    const bool add_bos = is_spm;

    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);

    const int n_chunk_max = tokens.size() / params.n_ctx;
@@ -177,7 +183,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
            const auto token_org = tokens[batch_start];

            // add BOS token for the first batch of each chunk
            if (j == 0) {
            if (add_bos && j == 0) {
                tokens[batch_start] = llama_token_bos(ctx);
            }
@@ -295,8 +301,10 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
    size_t hs_task_count = prompt_lines.size()/6;
    fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);

    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;

    // This is needed as usual for LLaMA models
    bool prepend_bos = true;
    const bool add_bos = is_spm;

    // Number of tasks to use when computing the score
    if ( params.hellaswag_tasks < hs_task_count ) {
@@ -352,14 +360,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
    std::vector<float> tok_logits(n_vocab);

    for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {

        // Tokenize the context to count tokens
        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
        std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
        size_t context_size = context_embd.size();

        // Do the 1st ending
        // In this case we include the context when evaluating
        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
        auto query_size = query_embd.size();
        //printf("First query: %d\n",(int)query_size);
llama.cpp
@@ -72,6 +72,7 @@
#include <numeric>
#include <queue>
#include <random>
#include <regex>
#include <sstream>
#include <thread>
#include <unordered_map>
@@ -112,6 +113,15 @@ static size_t utf8_len(char src) {
    return lookup[highbits];
}

void replace_all(std::string & s, const std::string & search, const std::string & replace) {
    for (size_t pos = 0; ; pos += replace.length()) {
        pos = s.find(search, pos);
        if (pos == std::string::npos) break;
        s.erase(pos, search.length());
        s.insert(pos, replace);
    }
}

static void zeros(std::ofstream & file, size_t n) {
    char zero = 0;
    for (size_t i = 0; i < n; ++i) {
@@ -929,11 +939,13 @@ struct llama_vocab {
        ttype type;
    };

    llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
    enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;

    std::unordered_map<token, id> token_to_id;
    std::vector<token_data> id_to_token;

    std::map<std::pair<std::string, std::string>, int> bpe_ranks;

    // default LLaMA special tokens
    id special_bos_id = 1;
    id special_eos_id = 2;
@@ -942,6 +954,20 @@ struct llama_vocab {
    id special_pad_id = -1;

    id linefeed_id = 13;

    int find_bpe_rank(std::string token_left, std::string token_right) const {
        replace_all(token_left, " ", "Ġ");
        replace_all(token_left, "\n", "Ċ");
        replace_all(token_right, " ", "Ġ");
        replace_all(token_right, "\n", "Ċ");

        auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
        if (it == bpe_ranks.end()) {
            return -1;
        }

        return it->second;
    }
};

struct llama_model {
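Because find_bpe_rank normalizes both halves to the byte-level symbols, callers can pass raw text fragments and still hit merges that were stored in GPT-2 form. A hedged illustration with a made-up rank, given a populated llama_vocab named vocab:

int r1 = vocab.find_bpe_rank(" hello", " world"); // " " becomes "Ġ"; returns the stored rank, e.g. 42
int r2 = vocab.find_bpe_rank("foo", "bar");       // no such merge, returns -1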
@@ -1634,6 +1660,30 @@ static void llm_load_vocab(
        vocab.type = LLAMA_VOCAB_TYPE_SPM;
    } else if (tokenizer_name == "gpt2") {
        vocab.type = LLAMA_VOCAB_TYPE_BPE;

        const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
        if (merges_keyidx == -1) {
            throw std::runtime_error("cannot find tokenizer merges in model file\n");
        }

        const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);

        for (int i = 0; i < n_merges; i++) {
            const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);

            std::string first;
            std::string second;

            const size_t pos = word.find(' ', 1);

            if (pos != std::string::npos) {
                first = word.substr(0, pos);
                second = word.substr(pos + 1);
            }

            // populate bpe ranks
            vocab.bpe_ranks.emplace(std::make_pair(first, second), i);
        }
    } else {
        LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
        LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
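Each LLM_KV_TOKENIZER_MERGES entry is a single string holding the two halves of one merge, separated by the first space found at or after index 1; the entry's array index becomes its merge rank. A small worked example with a hypothetical entry:

std::string word = "th e";                  // hypothetical merges entry
size_t pos = word.find(' ', 1);             // pos == 2
std::string first = word.substr(0, pos);    // "th"
std::string second = word.substr(pos + 1);  // "e"
// ("th", "e") is stored in bpe_ranks with rank i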
@@ -1654,7 +1704,6 @@ static void llm_load_vocab(
        token_data.text = std::move(word);
        token_data.score = scores[i];
        token_data.type = (llama_token_type) toktypes[i];

    }

    // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
@@ -1677,6 +1726,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
    LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
    LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
    LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
    LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (int) vocab.bpe_ranks.size());
    LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
    LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx);
    LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd);
@@ -2954,41 +3004,43 @@ static std::string llama_unescape_whitespace(const std::string& word) {
    return word;
}

// original implementation:
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
struct llama_tokenizer {
struct sp_symbol {
struct llm_symbol {
    using index = int;
    index prev;
    index next;
    const char * text;
    size_t n;
};
};

static_assert(std::is_trivially_copyable<sp_symbol>::value, "sp_symbol is not trivially copyable");
static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not trivially copyable");

struct sp_bigram {
// SPM tokenizer
// original implementation:
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4

struct llm_bigram_spm {
    struct comparator {
        bool operator()(sp_bigram & l, sp_bigram & r) {
        bool operator()(llm_bigram_spm & l, llm_bigram_spm & r) {
            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
        }
    };
    using queue_storage = std::vector<sp_bigram>;
    using queue = std::priority_queue<sp_bigram, queue_storage, comparator>;
    sp_symbol::index left;
    sp_symbol::index right;
    using queue_storage = std::vector<llm_bigram_spm>;
    using queue = std::priority_queue<llm_bigram_spm, queue_storage, comparator>;
    llm_symbol::index left;
    llm_symbol::index right;
    float score;
    size_t size;
};
};

    llama_tokenizer(const llama_vocab & vocab): vocab(vocab) {}
struct llm_tokenizer_spm {
    llm_tokenizer_spm(const llama_vocab & vocab): vocab(vocab) {}

    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
        // split string into utf8 chars
        int index = 0;
        size_t offs = 0;
        while (offs < text.size()) {
            sp_symbol sym;
            llm_symbol sym;
            size_t len = utf8_len(text[offs]);
            GGML_ASSERT(offs + len <= text.size());
            sym.text = text.c_str() + offs;
@@ -3043,7 +3095,7 @@ struct llama_tokenizer {
    }

private:
    void resegment(sp_symbol & symbol, std::vector<llama_vocab::id> & output) {
    void resegment(llm_symbol & symbol, std::vector<llama_vocab::id> & output) {
        auto text = std::string(symbol.text, symbol.n);
        auto token = vocab.token_to_id.find(text);
@@ -3086,7 +3138,7 @@ private:

        const auto & tok_data = vocab.id_to_token[(*token).second];

        sp_bigram bigram;
        llm_bigram_spm bigram;
        bigram.left = left;
        bigram.right = right;
        bigram.score = tok_data.score;
@@ -3100,19 +3152,208 @@ private:

    const llama_vocab & vocab;

    std::vector<sp_symbol> symbols;
    sp_bigram::queue work_queue;
    std::vector<llm_symbol> symbols;
    llm_bigram_spm::queue work_queue;

    std::map<std::string, std::pair<int, int>> rev_merge;
};

// BPE tokenizer
// adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License]
// tried to simplify unicode stuff, so most likely does not work 100% correctly!

// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused

struct llm_bigram_bpe {
    struct comparator {
        bool operator()(llm_bigram_bpe & l, llm_bigram_bpe & r) {
            return l.rank > r.rank || (l.rank == r.rank && l.left > r.left);
        }
    };

    using queue_storage = std::vector<llm_bigram_bpe>;
    using queue = std::priority_queue<llm_bigram_bpe, queue_storage, comparator>;
    llm_symbol::index left;
    llm_symbol::index right;
    std::string text;
    int rank;
    size_t size;
};

struct llm_tokenizer_bpe {
    llm_tokenizer_bpe(const llama_vocab & vocab, bool g2ws): vocab(vocab) { flag_g2ws = g2ws; }

    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
        int final_prev_index = -1;
        auto word_collection = bpe_gpt2_preprocess(text);

        symbols_final.clear();

        for (auto & word : word_collection) {
            work_queue = llm_bigram_bpe::queue();
            symbols.clear();

            int index = 0;
            size_t offset = 0;

            while (offset < word.size()) {
                llm_symbol sym;
                size_t char_len = std::min(word.size() - offset, (size_t) ::utf8_len(word[offset]));
                sym.text = word.c_str() + offset;
                sym.n = 1;
                sym.n = char_len;
                offset += sym.n;
                sym.prev = index - 1;
                sym.next = offset == word.size() ? -1 : index + 1;
                index++;
                symbols.emplace_back(sym);
            }
            for (size_t i = 1; i < symbols.size(); ++i) {
                add_new_bigram(i - 1, i);
            }

            // build token(s)
            while (!work_queue.empty()) {
                auto bigram = work_queue.top();
                work_queue.pop();

                auto & left_symbol = symbols[bigram.left];
                auto & right_symbol = symbols[bigram.right];

                if (left_symbol.n == 0 || right_symbol.n == 0) {
                    continue;
                }
                std::string left_token = std::string(left_symbol.text, left_symbol.n);
                std::string right_token = std::string(right_symbol.text, right_symbol.n);
                if (left_token + right_token != bigram.text) {
                    continue; // Skip this bigram if it's outdated
                }

                // merge the right sym into the left one
                left_symbol.n += right_symbol.n;
                right_symbol.n = 0;

                // remove the right sym from the chain
                left_symbol.next = right_symbol.next;
                if (right_symbol.next >= 0) {
                    symbols[right_symbol.next].prev = bigram.left;
                }

                add_new_bigram(left_symbol.prev, bigram.left); // left side of current symbol
                add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol
            }

            // add the finished tokens to the final list keeping correct order for next and prev
            for (auto & sym : symbols) {
                if (sym.n > 0) {
                    sym.prev = final_prev_index;
                    sym.next = -1;
                    if (final_prev_index != -1) {
                        symbols_final[final_prev_index].next = symbols_final.size();
                    }
                    symbols_final.emplace_back(sym);
                    final_prev_index = symbols_final.size() - 1;
                }
            }
        }

        symbols = symbols_final;

        if (!symbols.empty()) {
            for (int i = 0; i != -1; i = symbols[i].next) {
                auto & symbol = symbols[i];
                if (symbol.n == 0) {
                    continue;
                }

                const std::string str = std::string(symbol.text, symbol.n);
                const auto token = vocab.token_to_id.find(str);

                if (token == vocab.token_to_id.end()) {
                    for (auto j = str.begin(); j != str.end(); ++j) {
                        std::string byte_str(1, *j);
                        auto token_multibyte = vocab.token_to_id.find(byte_str);
                        if (token_multibyte == vocab.token_to_id.end()) {
                            fprintf(stderr,"ERROR: byte not found in vocab: '%s'\n", byte_str.c_str());
                        }
                        output.push_back((*token_multibyte).second);
                    }
                } else {
                    output.push_back((*token).second);
                }
            }
        }
    }

private:
    void add_new_bigram(int left, int right) {
        if (left == -1 || right == -1) {
            return;
        }

        std::string left_token = std::string(symbols[left].text, symbols[left].n);
        std::string right_token = std::string(symbols[right].text, symbols[right].n);

        int rank_found = -1;

        rank_found = vocab.find_bpe_rank(left_token, right_token);

        if (rank_found < 0) {
            return;
        }

        llm_bigram_bpe bigram;

        bigram.left = left;
        bigram.right = right;
        bigram.text = left_token + right_token;
        bigram.size = left_token.size() + right_token.size();
        bigram.rank = rank_found;

        work_queue.push(bigram);
    }

    // probably not 100% correct
    static std::vector<std::string> bpe_gpt2_preprocess(std::string text) {
        std::vector<std::string> words;

        // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
        const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
        const std::regex re(pattern);
        std::smatch m;

        while (std::regex_search(text, m, re)) {
            for (auto x : m) {
                words.push_back(x);
            }
            text = m.suffix();
        }

        return words;
    }

    bool flag_g2ws = false;

    const llama_vocab & vocab;

    std::vector<llm_symbol> symbols;
    std::vector<llm_symbol> symbols_final;

    llm_bigram_bpe::queue work_queue;
};

static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) {
    llama_tokenizer tokenizer(vocab);
    std::vector<llama_vocab::id> output;

    if (raw_text.empty()) {
        return output;
    }

    switch (vocab.type) {
        case LLAMA_VOCAB_TYPE_SPM:
            {
                llm_tokenizer_spm tokenizer(vocab);

                if (bos) {
                    output.push_back(vocab.special_bos_id);
                }
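To make the merge loop above concrete, here is a hedged walk-through on a toy word with made-up merges and ranks (the comparator pops the lowest rank first):

// word   : "lower"
// merges : ("l","o") -> rank 0, ("lo","w") -> rank 1, ("e","r") -> rank 2
//
// initial symbols : l o w e r      (queued bigrams: ("l","o")=0, ("e","r")=2)
// pop ("l","o")   : lo w e r       (new bigram ("lo","w")=1 is queued)
// pop ("lo","w")  : low e r
// pop ("e","r")   : low er
//
// the surviving symbols "low" and "er" are then looked up in token_to_id;
// anything missing falls back to single-byte tokens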
@@ -3125,6 +3366,19 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                }

                tokenizer.tokenize(text, output);
            } break;
        case LLAMA_VOCAB_TYPE_BPE:
            {
                llm_tokenizer_bpe tokenizer(vocab, escape);

                if (bos && vocab.special_bos_id != -1) {
                    output.push_back(vocab.special_bos_id);
                }

                tokenizer.tokenize(raw_text, output);
            } break;
    };

    return output;
}
@@ -4988,6 +5242,10 @@ int llama_n_embd(const struct llama_context * ctx) {
    return ctx->model.hparams.n_embd;
}

enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) {
    return ctx->model.vocab.type;
}

int llama_model_n_vocab(const struct llama_model * model) {
    return model->vocab.id_to_token.size();
}
llama.h
@@ -247,6 +247,8 @@ extern "C" {
    LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
    LLAMA_API int llama_n_embd (const struct llama_context * ctx);

    LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);

    LLAMA_API int llama_model_n_vocab(const struct llama_model * model);
    LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
    LLAMA_API int llama_model_n_embd (const struct llama_model * model);
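A short sketch of how an application might branch on the newly exposed vocab type; the helper name is hypothetical and error handling is omitted:

// returns a human-readable description of the tokenizer kind for a loaded context
static const char * vocab_type_name(const struct llama_context * ctx) {
    switch (llama_vocab_type(ctx)) {
        case LLAMA_VOCAB_TYPE_SPM: return "SPM (SentencePiece; BOS added by default)";
        case LLAMA_VOCAB_TYPE_BPE: return "BPE (GPT-2 style; no BOS by default)";
    }
    return "unknown";
}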