Compare commits

...
Sign in to create a new pull request.

2 commits

Author SHA1 Message Date
LostRuins Concedo
90a0349349
recommended way to check if the version is 0.3, as requested by ngxson
recommended way to check if the version is 0.3, as requested by ngxson
2025-01-19 21:43:59 +08:00
Concedo
b5486956ff added rudimentary support for outetts v0.3 500m and 1b models 2025-01-18 18:48:49 +08:00

View file

@ -371,7 +371,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
}
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
static std::string process_text(const std::string & text) {
static std::string process_text(const std::string & text, bool is_version_0_3) {
// For now I skipped text romanization as I am unsure how to handle
// uroman and MeCab implementations in C++
@ -401,7 +401,7 @@ static std::string process_text(const std::string & text) {
if (c == ' ') {
prompt_clean += "<|text_sep|>";
*/
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), is_version_0_3?"<|space|>":"<|text_sep|>");
return processed_text;
}
@ -425,8 +425,7 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
}
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
const std::string& delimiter = "<|text_sep|>";
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const std::string& delimiter) {
std::vector<llama_token> result;
size_t start = 0;
@ -523,6 +522,11 @@ int main(int argc, char ** argv) {
std::vector<llama_token> codes;
std::vector<llama_token> guide_tokens;
//determine OuteTTS version and vocab code offset. v0.2 does not have <|space|>, but v0.3 does
const bool is_version_0_3 = (common_get_builtin_chat_template(model_ttc) == "outetts-0.3");
//determine the offset of the first audio code token
const int cts_offset = common_tokenize(vocab,"<|0|>",false,true)[0];
// process prompt and generate voice codes
{
LOG_INF("%s: constructing prompt ..\n", __func__);
@ -531,13 +535,17 @@ int main(int argc, char ** argv) {
prompt_init(prompt_inp, vocab);
prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
if (is_version_0_3) {
prompt_add(prompt_inp, vocab, "<|text_start|>the<|space|>overall<|space|>package<|space|>from<|space|>just<|space|>two<|space|>people<|space|>is<|space|>pretty<|space|>remarkable<|space|>sure<|space|>i<|space|>have<|space|>some<|space|>critiques<|space|>about<|space|>some<|space|>of<|space|>the<|space|>gameplay<|space|>aspects<|space|>but<|space|>its<|space|>still<|space|>really<|space|>enjoyable<|space|>and<|space|>it<|space|>looks<|space|>lovely<|space|>", false, true);
} else {
prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
}
// convert the input text into the necessary format expected by OuteTTS
{
std::string prompt_clean = process_text(params.prompt);
std::string prompt_clean = process_text(params.prompt, is_version_0_3);
if (params.vocoder.use_guide_tokens) {
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
guide_tokens = prepare_guide_tokens(vocab, prompt_clean, is_version_0_3?"<|space|>":"<|text_sep|>");
}
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
@ -549,8 +557,8 @@ int main(int argc, char ** argv) {
// disabled to save time on tokenizing each time
// TODO: load voices from the json files
#if 0
const std::string voice_data = R"(<|audio_start|>
#if 1
std::string voice_data = R"(<|audio_start|>
the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
@ -582,12 +590,19 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
auto tmp = common_tokenize(vocab, voice_data, false, true);
printf("\n\n");
for (int i = 0; i < tmp.size(); ++i) {
printf("%d, ", tmp[i]);
if (is_version_0_3)
{
voice_data = std::regex_replace(voice_data, std::regex(R"(<\|code_start\|>)"), "");
voice_data = std::regex_replace(voice_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
}
printf("\n\n");
prompt_add(prompt_inp, vocab, voice_data, false, true);
// printf("\n\n");
// for (int i = 0; i < tmp.size(); ++i) {
// printf("%d, ", tmp[i]);
// }
// printf("\n\n");
#else
prompt_add(prompt_inp, llama_tokens {
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
@ -882,7 +897,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
}
// remove all non-audio tokens (i.e. < 151672 || > 155772)
codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < 151672 || t > 155772; }), codes.end());
codes.erase(std::remove_if(codes.begin(), codes.end(), [cts_offset](llama_token t) { return t < cts_offset || t > (cts_offset+4100); }), codes.end());
{
const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);
@ -891,7 +906,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
}
for (auto & token : codes) {
token -= 151672;
token -= cts_offset;
}
const auto t_voc_start = ggml_time_us();