ChatON: tokenize keeping in mind the taggedMessage subparts

Initial go
HanishKVC 2024-05-08 18:29:48 +05:30
parent 8dfa31bb91
commit b6da7d9c9d


@@ -628,6 +628,51 @@ inline int32_t chaton_tmpl_apply_ex_capi(
    return taggedLength;
}
// Copied from common.cpp
std::vector<llama_token> chaton_llama_tokenize(
        const struct llama_model * model,
        const std::string & text,
        bool add_special,
        bool parse_special) {
    LOGLN("DBUG:%s:%s:special[add:%d, parse:%d]", __func__, text.c_str(), add_special, parse_special);
    // upper limit for the number of tokens
    int n_tokens = text.length() + 2 * add_special;
    std::vector<llama_token> result(n_tokens);
    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
    if (n_tokens < 0) {
        // buffer was too small; -n_tokens is the required size, so resize and retry
        result.resize(-n_tokens);
        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
        GGML_ASSERT(check == -n_tokens);
    } else {
        result.resize(n_tokens);
    }
    return result;
}
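// Note on the pattern above: llama_tokenize returns the negative of the
// required token count when the supplied buffer is too small, which is why
// the helper resizes to -n_tokens and retries. A minimal usage sketch, not
// part of this commit (model is assumed to be a loaded llama_model*):
//
//     auto toks = chaton_llama_tokenize(model, "hello world", /*add_special=*/true, /*parse_special=*/false);
//     LOGLN("INFO:%s:got %zu tokens", __func__, toks.size());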
// Tokenize the passed taggedText, keeping in mind the subparts within and
// in turn whether or not to parse special tokens in each of them (partsTypes).
std::vector<llama_token> chaton_llama_tokenize_ex(
        const llama_context *ctx,
        const std::string &taggedText,
        const std::string &partsTypes,
        const std::vector<int32_t> &partsLengths,
        bool addSpecial
        ) {
    std::vector<llama_token> tokens;
    int iPart = 0;
    int iStart = 0;
    for (auto partLen : partsLengths) {
        auto partType = partsTypes[iPart];
        iPart += 1;
        auto msgPart = taggedText.substr(iStart, partLen);
        iStart += partLen;
        // parse special tokens only in the special (tag) subparts
        bool parseSpecial = (partType == ChatParts::S);
        auto curTokens = chaton_llama_tokenize(llama_get_model(ctx), msgPart, addSpecial, parseSpecial);
        tokens.insert(tokens.end(), curTokens.begin(), curTokens.end());
    }
    return tokens;
}
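// A hedged usage sketch, not part of this commit: the parts metadata is
// assumed to be produced alongside the tagged text (e.g. while applying the
// chat template), with one partsTypes char and one partsLengths entry per
// subpart. For a tagged text like "<|user|>\nhello", the tag subpart
// "<|user|>\n" (9 chars) would be tokenized with parse_special=true and the
// user text "hello" (5 chars) with parse_special=false. The 's'/'n' type
// encoding below is an assumption for illustration.
//
//     std::string taggedText = "<|user|>\nhello";
//     std::string partsTypes = "sn";              // s: special/tag, n: normal text
//     std::vector<int32_t> partsLengths = {9, 5};
//     auto tokens = chaton_llama_tokenize_ex(ctx, taggedText, partsTypes, partsLengths, false);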
/**
 * if tmpl is