applied linting suggestions, updated to the latest llama_vocab changes, added a safety check, and added a newline at the start of the guide tokens

Concedo 2025-01-18 12:53:18 +08:00
parent 9fa30422dc
commit a6013ded42


@@ -425,17 +425,19 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
     prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
 }
 
-static std::vector<llama_token> prepare_guide_tokens(const llama_model * model, const std::string& str)
-{
+static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
     const std::string& delimiter = "<|text_sep|>";
 
     std::vector<llama_token> result;
     size_t start = 0;
     size_t end = str.find(delimiter);
 
+    //first token is always a newline, as it was not previously added
+    result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
+
     while (end != std::string::npos) {
         std::string current_word = str.substr(start, end - start);
-        auto tmp = common_tokenize(model, current_word, false, true);
+        auto tmp = common_tokenize(vocab, current_word, false, true);
         result.push_back(tmp[0]);
         start = end + delimiter.length();
         end = str.find(delimiter, start);
@@ -443,8 +445,10 @@ static std::vector<llama_token> prepare_guide_tokens(const llama_model * model,
 
     // Add the last part
     std::string current_word = str.substr(start);
-    auto tmp = common_tokenize(model, current_word, false, true);
-    result.push_back(tmp[0]);
+    auto tmp = common_tokenize(vocab, current_word, false, true);
+    if (tmp.size() > 0) {
+        result.push_back(tmp[0]);
+    }
 
     return result;
 }
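
The reworked prepare_guide_tokens now always seeds the result with a newline token and skips an empty trailing segment (the new safety check). For intuition, here is a dependency-free sketch of that control flow; first_token() is a hypothetical stand-in for common_tokenize(vocab, word, false, true)[0], and std::string pieces stand in for llama_token ids, so this illustrates the logic rather than the real llama.cpp API:

    #include <iostream>
    #include <string>
    #include <vector>

    // hypothetical stand-in for common_tokenize(vocab, word, false, true)[0]
    static std::string first_token(const std::string & word) {
        return word;
    }

    static std::vector<std::string> prepare_guide_tokens_sketch(const std::string & str) {
        const std::string delimiter = "<|text_sep|>";
        std::vector<std::string> result;

        // first entry is always a newline, mirroring this commit's change
        result.push_back("\n");

        size_t start = 0;
        size_t end = str.find(delimiter);
        while (end != std::string::npos) {
            result.push_back(first_token(str.substr(start, end - start)));
            start = end + delimiter.length();
            end = str.find(delimiter, start);
        }

        // the last part can be empty (e.g. a trailing separator); skipping it
        // plays the role of the new tmp.size() > 0 safety check
        std::string last = str.substr(start);
        if (!last.empty()) {
            result.push_back(first_token(last));
        }
        return result;
    }

    int main() {
        // output: one bracketed entry per guide token; the empty part after
        // the trailing separator is skipped
        for (const auto & t : prepare_guide_tokens_sketch("hello<|text_sep|>world<|text_sep|>")) {
            std::cout << "[" << t << "] ";
        }
        std::cout << "\n";
    }

Running it on "hello<|text_sep|>world<|text_sep|>" yields the newline entry, "hello", and "world"; the empty segment after the trailing separator is dropped, which is exactly the case the tmp.size() > 0 check guards against in the real code.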
@@ -532,9 +536,8 @@ int main(int argc, char ** argv) {
     // convert the input text into the necessary format expected by OuteTTS
     {
         std::string prompt_clean = process_text(params.prompt);
-        if(params.vocoder.use_guide_tokens)
-        {
-            guide_tokens = prepare_guide_tokens(model_ttc,prompt_clean);
+        if (params.vocoder.use_guide_tokens) {
+            guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
         }
 
         LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
@@ -761,8 +764,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
             llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);
 
             //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
-            if(!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id))
-            {
+            if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
                 llama_token guide_token = guide_tokens[0];
                 guide_tokens.erase(guide_tokens.begin());
                 new_token_id = guide_token; //ensure correct word fragment is used
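
On the consumption side the guide tokens behave like a queue: each sampled token is overridden with the next queued word-initial token until the queue is empty, unless the sampled token is a control or end-of-generation token. A minimal sketch, assuming plain int ids and stub lambdas in place of common_sampler_sample, llama_vocab_is_control, and llama_vocab_is_eog (the toggling of next_token_uses_guide_token around word boundaries is omitted for brevity):

    #include <deque>
    #include <iostream>

    int main() {
        // first token of each word, as produced by prepare_guide_tokens
        std::deque<int> guide_tokens = {101, 202, 303};
        bool next_token_uses_guide_token = true;

        auto sample            = []()    { return 999;   }; // stub sampler
        auto is_control_or_eog = [](int) { return false; }; // stub vocab checks

        for (int step = 0; step < 4; ++step) {
            int new_token_id = sample();
            if (!guide_tokens.empty() && next_token_uses_guide_token && !is_control_or_eog(new_token_id)) {
                new_token_id = guide_tokens.front(); // force the correct word fragment
                guide_tokens.pop_front();
            }
            // once the queue is empty, sampling proceeds unmodified (step 3 -> 999)
            std::cout << "step " << step << " -> token " << new_token_id << "\n";
        }
    }

A std::deque is used here only for the cheap pop from the front; the real code erases guide_tokens.begin() from a std::vector, which amounts to the same thing for short queues.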