Rename token attributes

This commit is contained in:
jaime-m-p 2024-06-04 00:56:22 +02:00
parent ac40ff0e50
commit 18f5fc766b
2 changed files with 53 additions and 54 deletions

View file

@ -2147,15 +2147,14 @@ struct llama_control_vector {
}; };
struct llama_vocab { struct llama_vocab {
using id = int32_t; using id = int32_t;
using token = std::string; using token = std::string;
using ttype = llama_token_type; using tattr = llama_token_attr;
using tattribs = llama_token_attribs;
struct token_data { struct token_data {
token text; token text;
float score; float score;
tattribs attribs; tattr attr;
}; };
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
@ -4739,20 +4738,20 @@ static void llm_load_vocab(
vocab.token_to_id[word] = i; vocab.token_to_id[word] = i;
auto & token_data = vocab.id_to_token[i]; auto & token_data = vocab.id_to_token[i];
token_data.text = std::move(word); token_data.text = std::move(word);
token_data.score = scores ? scores[i] : 0.0f; token_data.score = scores ? scores[i] : 0.0f;
token_data.attribs = LLAMA_TOKEN_ATTRIB_NORMAL; token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;
if (toktypes) { //TODO: remove, required until per token attribs are available from GGUF file if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file
switch(toktypes[i]) { switch(toktypes[i]) {
case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attribs = LLAMA_TOKEN_ATTRIB_UNKNOWN; break; case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break;
case LLAMA_TOKEN_TYPE_UNUSED: token_data.attribs = LLAMA_TOKEN_ATTRIB_UNUSED; break; case LLAMA_TOKEN_TYPE_UNUSED: token_data.attr = LLAMA_TOKEN_ATTR_UNUSED; break;
case LLAMA_TOKEN_TYPE_NORMAL: token_data.attribs = LLAMA_TOKEN_ATTRIB_NORMAL; break; case LLAMA_TOKEN_TYPE_NORMAL: token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; break;
case LLAMA_TOKEN_TYPE_CONTROL: token_data.attribs = LLAMA_TOKEN_ATTRIB_CONTROL; break; case LLAMA_TOKEN_TYPE_CONTROL: token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; break;
case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attribs = LLAMA_TOKEN_ATTRIB_USER_DEFINED; break; case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
case LLAMA_TOKEN_TYPE_BYTE: token_data.attribs = LLAMA_TOKEN_ATTRIB_BYTE; break; case LLAMA_TOKEN_TYPE_BYTE: token_data.attr = LLAMA_TOKEN_ATTR_BYTE; break;
case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attribs = LLAMA_TOKEN_ATTRIB_UNDEFINED; break; case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break;
default: token_data.attribs = LLAMA_TOKEN_ATTRIB_UNDEFINED; break; default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break;
} }
} }
} }
@ -4845,7 +4844,7 @@ static void llm_load_vocab(
// build special tokens cache // build special tokens cache
{ {
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) { for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
if (!(vocab.id_to_token[id].attribs & LLAMA_TOKEN_ATTRIB_NORMAL)) { if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
vocab.cache_special_tokens.push_back(id); vocab.cache_special_tokens.push_back(id);
} }
} }
@ -4883,7 +4882,7 @@ static void llm_load_vocab(
// Handle per token attributes // Handle per token attributes
//NOTE: Each model customizes per token attributes. //NOTE: Each model customizes per token attributes.
//NOTE: Per token attributes are missing from the GGUF file. //NOTE: Per token attributes are missing from the GGUF file.
//TODO: Extract attribs from GGUF file. //TODO: Extract attributes from GGUF file.
{ {
auto _contains_any = [] (const std::string &str, const std::vector<std::string> &substrs) -> bool { auto _contains_any = [] (const std::string &str, const std::vector<std::string> &substrs) -> bool {
for (auto substr : substrs) { for (auto substr : substrs) {
@ -4894,14 +4893,14 @@ static void llm_load_vocab(
return false; return false;
}; };
auto _set_tokenid_attrib = [&] (const llama_vocab::id id, llama_token_attribs attrib, bool value) { auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
uint32_t attribs = vocab.id_to_token.at(id).attribs; uint32_t current = vocab.id_to_token.at(id).attr;
attribs = value ? (attribs | attrib) : (attribs & ~attrib); current = value ? (current | attr) : (current & ~attr);
vocab.id_to_token[id].attribs = (llama_token_attribs) attribs; vocab.id_to_token[id].attr = (llama_token_attr) current;
}; };
auto _set_token_attrib = [&] (const std::string & token, llama_token_attribs attrib, bool value) { auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
_set_tokenid_attrib(vocab.token_to_id.at(token), attrib, value); _set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
}; };
std::string model_name; std::string model_name;
@ -4919,16 +4918,16 @@ static void llm_load_vocab(
// set attributes by model/tokenizer name // set attributes by model/tokenizer name
if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) { if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
_set_token_attrib("<mask>", LLAMA_TOKEN_ATTRIB_LSTRIP, true); _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
} else if (_contains_any(model_name, {"phi-3", "phi3"})) { } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
for (auto id : vocab.cache_special_tokens) { for (auto id : vocab.cache_special_tokens) {
_set_tokenid_attrib(id, LLAMA_TOKEN_ATTRIB_RSTRIP, true); _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
} }
for (auto token : {"</s>"}) { for (auto token : {"</s>"}) {
_set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, true); _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
} }
for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) { for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
_set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, false); _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
} }
} }
} }
@ -12683,27 +12682,27 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[id].attribs & LLAMA_TOKEN_ATTRIB_NORMAL; return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
} }
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[id].attribs & LLAMA_TOKEN_ATTRIB_UNKNOWN; return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
} }
static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[id].attribs & LLAMA_TOKEN_ATTRIB_CONTROL; return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
} }
static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[id].attribs & LLAMA_TOKEN_ATTRIB_BYTE; return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
} }
static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) { static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[id].attribs & LLAMA_TOKEN_ATTRIB_USER_DEFINED; return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
} }
static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
@ -13361,7 +13360,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
const int64_t left_reminder_offset = raw_text_base_offset + 0; const int64_t left_reminder_offset = raw_text_base_offset + 0;
int64_t left_reminder_length = match - raw_text_base_offset; int64_t left_reminder_length = match - raw_text_base_offset;
if (data.attribs & LLAMA_TOKEN_ATTRIB_LSTRIP) { if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) { while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
left_reminder_length--; left_reminder_length--;
} }
@ -13386,7 +13385,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
int64_t right_reminder_offset = match + special_token.length(); int64_t right_reminder_offset = match + special_token.length();
int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length()); int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
if (data.attribs & LLAMA_TOKEN_ATTRIB_RSTRIP) { if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) { while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
right_reminder_offset++; right_reminder_offset++;
right_reminder_length--; right_reminder_length--;
@ -18276,9 +18275,9 @@ float llama_token_get_score(const struct llama_model * model, llama_token token)
return model->vocab.id_to_token[token].score; return model->vocab.id_to_token[token].score;
} }
llama_token_attribs llama_token_get_attribs(const struct llama_model * model, llama_token token) { llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE); GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
return model->vocab.id_to_token[token].attribs; return model->vocab.id_to_token[token].attr;
} }
bool llama_token_is_eog(const struct llama_model * model, llama_token token) { bool llama_token_is_eog(const struct llama_model * model, llama_token token) {

28
llama.h
View file

@ -97,7 +97,7 @@ extern "C" {
LLAMA_ROPE_TYPE_GLM = 4, LLAMA_ROPE_TYPE_GLM = 4,
}; };
enum llama_token_type { //TODO: remove, required until per token attribs are available from GGUF file enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
LLAMA_TOKEN_TYPE_UNDEFINED = 0, LLAMA_TOKEN_TYPE_UNDEFINED = 0,
LLAMA_TOKEN_TYPE_NORMAL = 1, LLAMA_TOKEN_TYPE_NORMAL = 1,
LLAMA_TOKEN_TYPE_UNKNOWN = 2, LLAMA_TOKEN_TYPE_UNKNOWN = 2,
@ -107,18 +107,18 @@ extern "C" {
LLAMA_TOKEN_TYPE_BYTE = 6, LLAMA_TOKEN_TYPE_BYTE = 6,
}; };
enum llama_token_attribs { enum llama_token_attr {
LLAMA_TOKEN_ATTRIB_UNDEFINED = 0, LLAMA_TOKEN_ATTR_UNDEFINED = 0,
LLAMA_TOKEN_ATTRIB_UNKNOWN = 1 << 1, LLAMA_TOKEN_ATTR_UNKNOWN = 1 << 1,
LLAMA_TOKEN_ATTRIB_UNUSED = 1 << 2, LLAMA_TOKEN_ATTR_UNUSED = 1 << 2,
LLAMA_TOKEN_ATTRIB_NORMAL = 1 << 3, LLAMA_TOKEN_ATTR_NORMAL = 1 << 3,
LLAMA_TOKEN_ATTRIB_CONTROL = 1 << 4, // SPECIAL? LLAMA_TOKEN_ATTR_CONTROL = 1 << 4, // SPECIAL?
LLAMA_TOKEN_ATTRIB_USER_DEFINED = 1 << 5, LLAMA_TOKEN_ATTR_USER_DEFINED = 1 << 5,
LLAMA_TOKEN_ATTRIB_BYTE = 1 << 6, LLAMA_TOKEN_ATTR_BYTE = 1 << 6,
LLAMA_TOKEN_ATTRIB_NORMALIZED = 1 << 7, LLAMA_TOKEN_ATTR_NORMALIZED = 1 << 7,
LLAMA_TOKEN_ATTRIB_LSTRIP = 1 << 8, LLAMA_TOKEN_ATTR_LSTRIP = 1 << 8,
LLAMA_TOKEN_ATTRIB_RSTRIP = 1 << 9, LLAMA_TOKEN_ATTR_RSTRIP = 1 << 9,
LLAMA_TOKEN_ATTRIB_SINGLE_WORD = 1 << 10, LLAMA_TOKEN_ATTR_SINGLE_WORD = 1 << 10,
}; };
// model file types // model file types
@ -835,7 +835,7 @@ extern "C" {
LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
LLAMA_API enum llama_token_attribs llama_token_get_attribs(const struct llama_model * model, llama_token token); LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
// Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);