llama : add handling of byte tokens in UGM tokenizer (same as in SPM)
llama : fix preventing crashes when precompiled_charsmap is not present
This commit is contained in:
parent
c2c799cefa
commit
f4c03c0966
1 changed files with 39 additions and 35 deletions
74
llama.cpp
74
llama.cpp
|
@ -13335,7 +13335,8 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
|
||||||
GGML_ASSERT(llama_is_byte_token(vocab, id));
|
GGML_ASSERT(llama_is_byte_token(vocab, id));
|
||||||
const auto & token_data = vocab.id_to_token.at(id);
|
const auto & token_data = vocab.id_to_token.at(id);
|
||||||
switch (llama_vocab_get_type(vocab)) {
|
switch (llama_vocab_get_type(vocab)) {
|
||||||
case LLAMA_VOCAB_TYPE_SPM: {
|
case LLAMA_VOCAB_TYPE_SPM:
|
||||||
|
case LLAMA_VOCAB_TYPE_UGM: {
|
||||||
auto buf = token_data.text.substr(3, 2);
|
auto buf = token_data.text.substr(3, 2);
|
||||||
return strtol(buf.c_str(), NULL, 16);
|
return strtol(buf.c_str(), NULL, 16);
|
||||||
}
|
}
|
||||||
|
@ -13355,7 +13356,8 @@ static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
|
||||||
GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
|
GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
|
||||||
static const char * hex = "0123456789ABCDEF";
|
static const char * hex = "0123456789ABCDEF";
|
||||||
switch (llama_vocab_get_type(vocab)) {
|
switch (llama_vocab_get_type(vocab)) {
|
||||||
case LLAMA_VOCAB_TYPE_SPM: {
|
case LLAMA_VOCAB_TYPE_SPM:
|
||||||
|
case LLAMA_VOCAB_TYPE_UGM: {
|
||||||
const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
|
const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
|
||||||
auto token = vocab.token_to_id.find(buf);
|
auto token = vocab.token_to_id.find(buf);
|
||||||
if (token != vocab.token_to_id.end()) {
|
if (token != vocab.token_to_id.end()) {
|
||||||
|
@ -14242,36 +14244,38 @@ private:
|
||||||
size_t longest_prefix_length = 0;
|
size_t longest_prefix_length = 0;
|
||||||
size_t longest_prefix_offset = 0;
|
size_t longest_prefix_offset = 0;
|
||||||
|
|
||||||
struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
|
if (xcda_array_size > 0) {
|
||||||
|
struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
|
||||||
|
|
||||||
// Find the longest normalized sequence matching the input prefix by walking
|
// Find the longest normalized sequence matching the input prefix by walking
|
||||||
// the XOR-compressed compact double array (XCDA) starting from the root node
|
// the XOR-compressed compact double array (XCDA) starting from the root node
|
||||||
// We find the index of the next node by calculating BASE[s] ^ c where s is
|
// We find the index of the next node by calculating BASE[s] ^ c where s is
|
||||||
// the index of the previous node and c is a numerical character value
|
// the index of the previous node and c is a numerical character value
|
||||||
uint32_t node_index = 0;
|
uint32_t node_index = 0;
|
||||||
// get BASE of the root node
|
// get BASE of the root node
|
||||||
node_index = xcda_view.get_base(node_index);
|
node_index = xcda_view.get_base(node_index);
|
||||||
for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) {
|
for (size_t prefix_offset = input_offset; prefix_offset < input.size(); prefix_offset++) {
|
||||||
unsigned char c = input[prefix_offset];
|
unsigned char c = input[prefix_offset];
|
||||||
if (c == 0) {
|
if (c == 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
node_index ^= c;
|
node_index ^= c;
|
||||||
// if value of LCHECK is not c it means that this is not a child of
|
// if value of LCHECK is not c it means that this is not a child of
|
||||||
// the previous node, so we stop matching
|
// the previous node, so we stop matching
|
||||||
if (xcda_view.get_lcheck(node_index) != c) {
|
if (xcda_view.get_lcheck(node_index) != c) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
bool is_leaf = xcda_view.get_leaf(node_index);
|
bool is_leaf = xcda_view.get_leaf(node_index);
|
||||||
// get BASE of the current node
|
// get BASE of the current node
|
||||||
node_index ^= xcda_view.get_base(node_index);
|
node_index ^= xcda_view.get_base(node_index);
|
||||||
// if LEAF of the current node is true, it means that its BASE points to the node
|
// if LEAF of the current node is true, it means that its BASE points to the node
|
||||||
// containing index of replacement sequence for currently matched input prefix
|
// containing index of replacement sequence for currently matched input prefix
|
||||||
if (is_leaf)
|
if (is_leaf)
|
||||||
{
|
{
|
||||||
longest_prefix_length = prefix_offset - input_offset + 1;
|
longest_prefix_length = prefix_offset - input_offset + 1;
|
||||||
// get index of replacement sequence for currently matched input prefix
|
// get index of replacement sequence for currently matched input prefix
|
||||||
longest_prefix_offset = xcda_view.get_value(node_index);
|
longest_prefix_offset = xcda_view.get_value(node_index);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14299,11 +14303,11 @@ private:
|
||||||
// escaped space symbol - U+2581 (Lower One Eighth Block)
|
// escaped space symbol - U+2581 (Lower One Eighth Block)
|
||||||
const std::string escaped_space = "\xE2\x96\x81";
|
const std::string escaped_space = "\xE2\x96\x81";
|
||||||
|
|
||||||
char * prefix_replacements;
|
char * prefix_replacements = NULL;
|
||||||
size_t prefix_replacements_size;
|
size_t prefix_replacements_size = 0;
|
||||||
|
|
||||||
uint32_t * xcda_array;
|
uint32_t * xcda_array = NULL;
|
||||||
size_t xcda_array_size;
|
size_t xcda_array_size = 0;
|
||||||
|
|
||||||
struct naive_trie user_defined_token_matcher;
|
struct naive_trie user_defined_token_matcher;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue