From a6013ded42f2ba2119f802070242c48be3264461 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Sat, 18 Jan 2025 12:53:18 +0800 Subject: [PATCH] applied linting suggestions, updated to latest llama_vocab changes, added a safety check, added newline to guide token start --- examples/tts/tts.cpp | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index c73d0d94d..f78f76303 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -425,17 +425,19 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) { prompt_add(prompt, vocab, "<|im_start|>\n", true, true); } -static std::vector<llama_token> prepare_guide_tokens(const llama_model * model, const std::string& str) -{ +static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) { const std::string& delimiter = "<|text_sep|>"; std::vector<llama_token> result; size_t start = 0; size_t end = str.find(delimiter); + //first token is always a newline, as it was not previously added + result.push_back(common_tokenize(vocab, "\n", false, true)[0]); + while (end != std::string::npos) { std::string current_word = str.substr(start, end - start); - auto tmp = common_tokenize(model, current_word, false, true); + auto tmp = common_tokenize(vocab, current_word, false, true); result.push_back(tmp[0]); start = end + delimiter.length(); end = str.find(delimiter, start); @@ -443,8 +445,10 @@ static std::vector<llama_token> prepare_guide_tokens(const llama_model * model, // Add the last part std::string current_word = str.substr(start); - auto tmp = common_tokenize(model, current_word, false, true); - result.push_back(tmp[0]); + auto tmp = common_tokenize(vocab, current_word, false, true); + if (tmp.size() > 0) { + result.push_back(tmp[0]); + } return result; } @@ -532,9 +536,8 @@ int main(int argc, char ** argv) { // convert the input text into the necessary format expected by OuteTTS { 
std::string prompt_clean = process_text(params.prompt); - if(params.vocoder.use_guide_tokens) - { - guide_tokens = prepare_guide_tokens(model_ttc,prompt_clean); + if (params.vocoder.use_guide_tokens) { + guide_tokens = prepare_guide_tokens(vocab, prompt_clean); } LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str()); @@ -761,8 +764,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]); //guide tokens help prevent hallucinations by forcing the TTS to use the correct word - if(!guide_tokens.empty() && next_token_uses_guide_token && !llama_token_is_control(model_ttc, new_token_id) && !llama_token_is_eog(model_ttc, new_token_id)) - { + if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) { llama_token guide_token = guide_tokens[0]; guide_tokens.erase(guide_tokens.begin()); new_token_id = guide_token; //ensure correct word fragment is used