Refactor + add 'jina-v2' for testing 'lstrip'
This commit is contained in:
parent
ada961cec2
commit
01c9229186
2 changed files with 44 additions and 35 deletions
77
llama.cpp
77
llama.cpp
|
@ -4872,9 +4872,29 @@ static void llm_load_vocab(
|
||||||
//NOTE: Each model customizes per token attributes.
|
//NOTE: Each model customizes per token attributes.
|
||||||
//NOTE: Per token attributes are missing from the GGUF file.
|
//NOTE: Per token attributes are missing from the GGUF file.
|
||||||
//TODO: Merge llama_token_type and llama_token_attrib.
|
//TODO: Merge llama_token_type and llama_token_attrib.
|
||||||
|
//TODO: Extract attribs from GGUF file.
|
||||||
{
|
{
|
||||||
|
auto _contains_any = [] (const std::string &str, const std::vector<std::string> &substrs) -> bool {
|
||||||
|
for (auto substr : substrs) {
|
||||||
|
if (str.find(substr) < std::string::npos) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto _set_tokenid_attrib = [&] (const llama_vocab::id id, llama_token_attrib attrib, bool value) {
|
||||||
|
uint32_t attribs = vocab.id_to_token.at(id).attribs;
|
||||||
|
attribs = value ? (attribs | attrib) : (attribs & ~attrib);
|
||||||
|
vocab.id_to_token[id].attribs = (llama_token_attrib) attribs;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto _set_token_attrib = [&] (const std::string & token, llama_token_attrib attrib, bool value) {
|
||||||
|
_set_tokenid_attrib(vocab.token_to_id.at(token), attrib, value);
|
||||||
|
};
|
||||||
|
|
||||||
// convert token type as an attribute
|
// convert token type as an attribute
|
||||||
for (auto data : vocab.id_to_token) {
|
for (auto &data : vocab.id_to_token) {
|
||||||
uint32_t attrib = LLAMA_TOKEN_ATTRIB_UNDEFINED;
|
uint32_t attrib = LLAMA_TOKEN_ATTRIB_UNDEFINED;
|
||||||
attrib |= LLAMA_TOKEN_ATTRIB_UNKNOWN * (data.type == LLAMA_TOKEN_TYPE_UNKNOWN);
|
attrib |= LLAMA_TOKEN_ATTRIB_UNKNOWN * (data.type == LLAMA_TOKEN_TYPE_UNKNOWN);
|
||||||
attrib |= LLAMA_TOKEN_ATTRIB_UNUSED * (data.type == LLAMA_TOKEN_TYPE_UNUSED);
|
attrib |= LLAMA_TOKEN_ATTRIB_UNUSED * (data.type == LLAMA_TOKEN_TYPE_UNUSED);
|
||||||
|
@ -4885,44 +4905,31 @@ static void llm_load_vocab(
|
||||||
data.attribs = (llama_token_attrib) attrib;
|
data.attribs = (llama_token_attrib) attrib;
|
||||||
}
|
}
|
||||||
|
|
||||||
// set attributes by model name
|
|
||||||
std::string model_name;
|
std::string model_name;
|
||||||
if (ml.get_key(LLM_KV_GENERAL_NAME, model_name, false)) {
|
std::string tokenizer_pre;
|
||||||
std::transform(model_name.begin(), model_name.end(), model_name.begin(),
|
|
||||||
[] (const std::string::value_type x) {
|
|
||||||
return std::tolower(x);
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
auto _contains_any = [&model_name] (const std::vector<std::string> &substrs) -> bool {
|
ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
|
||||||
for (auto substr : substrs) {
|
ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
|
||||||
if (model_name.find(substr) < std::string::npos) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
|
|
||||||
auto _set_tokenid_attrib = [&] (const llama_vocab::id id, llama_token_attrib attrib, bool value) {
|
// model name to lowercase
|
||||||
uint32_t attribs = vocab.id_to_token[id].attribs;
|
std::transform(model_name.begin(), model_name.end(), model_name.begin(),
|
||||||
attribs = value ? (attribs | attrib) : (attribs & ~attrib);
|
[] (const std::string::value_type x) {
|
||||||
vocab.id_to_token[id].attribs = (llama_token_attrib) attribs;
|
return std::tolower(x);
|
||||||
};
|
}
|
||||||
|
);
|
||||||
|
|
||||||
auto _set_token_attrib = [&] (const std::string & token, llama_token_attrib attrib, bool value) {
|
// set attributes by model/tokenizer name
|
||||||
_set_tokenid_attrib(vocab.token_to_id.at(token), attrib, value);
|
if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
|
||||||
};
|
_set_token_attrib("<mask>", LLAMA_TOKEN_ATTRIB_LSTRIP, true);
|
||||||
|
} else if (_contains_any(model_name, {"phi-3", "phi3"})) {
|
||||||
if (_contains_any({"phi-3", "phi3"})) {
|
for (auto id : vocab.cache_special_tokens) {
|
||||||
for (auto id : vocab.cache_special_tokens) {
|
_set_tokenid_attrib(id, LLAMA_TOKEN_ATTRIB_RSTRIP, true);
|
||||||
_set_tokenid_attrib(id, LLAMA_TOKEN_ATTRIB_RSTRIP, true);
|
}
|
||||||
}
|
for (auto token : {"</s>"}) {
|
||||||
for (auto token : {"</s>"}) {
|
_set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, true);
|
||||||
_set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, true);
|
}
|
||||||
}
|
for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
|
||||||
for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
|
_set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, false);
|
||||||
_set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, false);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -156,6 +156,8 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
|
||||||
'<s>a', # Phi-3 fail
|
'<s>a', # Phi-3 fail
|
||||||
'<unk><|endoftext|><s>', # Phi-3 fail
|
'<unk><|endoftext|><s>', # Phi-3 fail
|
||||||
'a\na', # TODO: Bert fail
|
'a\na', # TODO: Bert fail
|
||||||
|
'a </s> b', # rstrip phi-3
|
||||||
|
'a <mask> b', # lstrip jina-v2
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue