diff --git a/common/arg.cpp b/common/arg.cpp
index 27886b84e..7f5c8287a 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2224,8 +2224,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.hf_file = "OuteTTS-0.2-500M-Q8_0.gguf";
params.vocoder.hf_repo = "ggml-org/WavTokenizer";
params.vocoder.hf_file = "WavTokenizer-Large-75-F16.gguf";
+ params.ctx_shift = false; // for better results
}
- ).set_examples({LLAMA_EXAMPLE_TTS}));
+ ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
return ctx_arg;
}
diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt
index 1b7cc8c13..6e84abd1b 100644
--- a/examples/server/CMakeLists.txt
+++ b/examples/server/CMakeLists.txt
@@ -13,6 +13,7 @@ set(TARGET_SRCS
server.cpp
utils.hpp
httplib.h
+ ../tts/tts-impl.cpp
)
set(PUBLIC_ASSETS
index.html.gz
diff --git a/examples/server/public_tts/index.html b/examples/server/public_tts/index.html
new file mode 100644
index 000000000..a7a8d3784
--- /dev/null
+++ b/examples/server/public_tts/index.html
@@ -0,0 +1,132 @@
+
+
+
+
+
+ llama.cpp TTS
+
+
+
+
+ llama.cpp TTS
+
+ Input text:
+
+
+
+ Status: ready
+
+
+
+
+
+
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 127323e77..732067d3a 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -7,6 +7,7 @@
#include "log.h"
#include "sampling.h"
#include "speculative.h"
+#include "../tts/tts-impl.hpp"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
@@ -65,6 +66,7 @@ enum server_task_type {
SERVER_TASK_TYPE_SLOT_RESTORE,
SERVER_TASK_TYPE_SLOT_ERASE,
SERVER_TASK_TYPE_SET_LORA,
+ SERVER_TASK_TYPE_TTS_EMBD,
};
enum oaicompat_type {
@@ -551,12 +553,12 @@ struct server_task_result_cmpl_final : server_task_result {
bool post_sampling_probs;
std::vector<completion_token_output> probs_output;
- std::vector<std::string> response_fields;
+ std::vector<std::string> response_fields;
slot_params generation_params;
// OAI-compat fields
- bool verbose = false;
+ bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
@@ -937,6 +939,20 @@ struct server_task_result_embd : server_task_result {
}
};
+struct server_task_result_tts_embd : server_task_result {
+ int index = 0;
+ std::vector<float> embd;
+ double t_ms = 0.0;
+
+ virtual int get_index() override {
+ return index; // unused
+ }
+
+ virtual json to_json() override {
+ return json {}; // unused
+ }
+};
+
struct server_task_result_rerank : server_task_result {
int index = 0;
float score = -1e6;
@@ -1629,6 +1645,7 @@ struct server_context {
// note: keep these alive - they determine the lifetime of the model, context, etc.
common_init_result llama_init;
common_init_result llama_init_dft;
+ common_init_result llama_init_vocoder;
llama_model * model = nullptr;
llama_context * ctx = nullptr;
@@ -1731,6 +1748,20 @@ struct server_context {
cparams_dft.type_v = GGML_TYPE_F16;
}
+ if (!params.vocoder.model.empty()) {
+ common_params v_params = params_base;
+ v_params.model = params.vocoder.model;
+ v_params.model_url = params.vocoder.model_url;
+ v_params.hf_repo = params.vocoder.hf_repo;
+ v_params.hf_file = params.vocoder.hf_file;
+ v_params.embedding = true;
+ v_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
+ // make sure the vocoder has a sufficient batch size
+ v_params.n_batch = v_params.n_ctx;
+ v_params.n_ubatch = v_params.n_ctx;
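+ // (prompts with more codes than n_ubatch are rejected later in SERVER_TASK_TYPE_TTS_EMBD)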
+ llama_init_vocoder = common_init_from_params(v_params);
+ }
+
return true;
}
@@ -2578,6 +2609,34 @@ struct server_context {
res->id = task.id;
queue_results.send(std::move(res));
} break;
+ case SERVER_TASK_TYPE_TTS_EMBD:
+ {
+ const auto ctx_cts = llama_init_vocoder.context.get();
+ const int n_ubatch = llama_n_ubatch(ctx_cts);
+ const int n_codes = (int) task.prompt_tokens.size();
+ if (n_codes > n_ubatch) {
+ send_error(task, string_format("Number of codes (%d) exceeds the maximum ubatch of vocoder model (%d)", n_codes, n_ubatch), ERROR_TYPE_INVALID_REQUEST);
+ break;
+ }
+
+ std::vector<float> embd;
+ uint64_t t_start = ggml_time_us();
+ SRV_DBG("tts_get_embd with %d codes", n_codes);
+ int status = tts_get_embd(ctx_cts, task.prompt_tokens, embd);
+ if (status != 0) {
+ send_error(task, string_format("Failed to get TTS embedding, status code = %d", status), ERROR_TYPE_SERVER);
+ break;
+ }
+ if (embd.size() == 0) {
+ send_error(task, "no embeddings is returned from tts_get_embd()", ERROR_TYPE_SERVER);
+ break;
+ }
+ auto res = std::make_unique<server_task_result_tts_embd>();
+ res->id = task.id;
+ res->embd = std::move(embd);
+ res->t_ms = (ggml_time_us() - t_start) / 1e3;
+ queue_results.send(std::move(res));
+ } break;
}
}
@@ -3148,7 +3207,10 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
LOG_DBG("request: %s\n", req.body.c_str());
- LOG_DBG("response: %s\n", res.body.c_str());
+ // exclude TTS endpoint, because response is raw WAV data
+ if (req.path != "/v1/audio/speech") {
+ LOG_DBG("response: %s\n", res.body.c_str());
+ }
}
std::function<void(int)> shutdown_handler;
@@ -4076,6 +4138,152 @@ int main(int argc, char ** argv) {
res_ok(res, root);
};
+ // TODO: this is a POC, not optimized for performance
+ const auto handle_speech = [&](const httplib::Request & req, httplib::Response & res) {
+ if (ctx_server.llama_init_vocoder.context.get() == nullptr) {
+ res_error(res, format_error_response("This server does not support TTS. Start it with `--model-vocoder`", ERROR_TYPE_NOT_SUPPORTED));
+ return;
+ }
+
+ const json body = json::parse(req.body);
+
+ // ignore "model" and "voice" for now
+ const std::string input = json_value(body, "input", std::string());
+ const std::string response_format = json_value(body, "response_format", std::string("wav"));
+ const float speed = json_value(body, "speed", 1.0f);
+ if (input.empty()) {
+ res_error(res, format_error_response("\"input\" must be a non-empty string", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ if (response_format != "wav") {
+ res_error(res, format_error_response("\"response_format\" must be \"wav\"", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ if (speed != 1.0f) {
+ res_error(res, format_error_response("\"speed\" must be 1.0", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+
+ llama_tokens codes;
+ result_timings ttc_timings;
+ // convert text to codes
+ {
+ server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.index = 0;
+ task.prompt_tokens = tts_preprocess_prompt(ctx_server.model, input);
+
+ task.params.stream = false;
+ task.params.return_tokens = true;
+ task.params.sampling.temp = 0.0;
+ task.params.sampling.top_k = 1;
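+ // note: temperature 0 with top_k 1 makes the code generation greedy/deterministic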
+
+ ctx_server.queue_results.add_waiting_tasks({task});
+ ctx_server.queue_tasks.post(task);
+
+ // get the result
+ const server_task_result_ptr raw_result = ctx_server.queue_results.recv(task.id);
+ if (raw_result->is_error()) {
+ res_error(res, raw_result->to_json());
+ return;
+ }
+ const server_task_result_cmpl_final * result = dynamic_cast<const server_task_result_cmpl_final *>(raw_result.get());
+ GGML_ASSERT(result != nullptr);
+ GGML_ASSERT(!result->tokens.empty());
+ codes = std::move(result->tokens);
+
+ // debug
+ // SRV_DBG("codes str (before filter) = %s\n", common_detokenize(ctx_server.ctx, codes, true).c_str());
+
+ // post-process codes
+ // remove all non-audio tokens (i.e. < 151672 || > 155772)
+ codes.erase(std::remove_if(
+ codes.begin(),
+ codes.end(),
+ [](llama_token t) { return t < 151672 || t > 155772; }),
+ codes.end());
+ SRV_DBG("codes size = %d\n", (int) codes.size());
+
+ ttc_timings = std::move(result->timings);
+ }
+
+ // debug
+ // SRV_DBG("codes str = %s\n", common_detokenize(ctx_server.ctx, codes, true).c_str());
+
+ // convert codes to embeddings
+ int n_embd = llama_n_embd(ctx_server.llama_init_vocoder.model.get());
+ int n_codes = -1;
+ double t_voc_ms = 0.0;
+ std::vector<float> embd;
+ {
+ server_task task = server_task(SERVER_TASK_TYPE_TTS_EMBD);
+ task.id = ctx_server.queue_tasks.get_new_id();
+ task.prompt_tokens = std::move(codes);
+
+ ctx_server.queue_results.add_waiting_tasks({task});
+ ctx_server.queue_tasks.post(task);
+
+ // get the result
+ const server_task_result_ptr raw_result = ctx_server.queue_results.recv(task.id);
+ if (raw_result->is_error()) {
+ res_error(res, raw_result->to_json());
+ return;
+ }
+ const server_task_result_tts_embd * result = dynamic_cast<const server_task_result_tts_embd *>(raw_result.get());
+ GGML_ASSERT(result != nullptr);
+ GGML_ASSERT(!result->embd.empty());
+
+ // flatten the array
+ n_codes = result->embd.size() / n_embd;
+ embd = std::move(result->embd);
+ t_voc_ms = result->t_ms;
+ SRV_DBG("tts embd n_code = %d\n", n_codes);
+ SRV_DBG("tts embd size = %zu\n", embd.size());
+ SRV_DBG("tts embd t_voc_ms = %lf\n", t_voc_ms);
+ GGML_ASSERT(n_codes > 0);
+ }
+
+ // convert embeddings to wav
+ // will be freed by chunked_content_provider
+ const auto t_spec_start = ggml_time_us();
+ std::vector<float> audio = tts_embd_to_audio(embd.data(), n_codes, n_embd, params.cpuparams.n_threads);
+ double t_spec_ms = (ggml_time_us() - t_spec_start) / 1e3;
+
+ // for now, we can only leave timings in response headers, mostly for debugging
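+ // ttc = text-to-codes (LLM), voc = vocoder (codes-to-embeddings), spec = embeddings-to-audio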
+ res.set_header("X-timings-ttc", ttc_timings.to_json().dump());
+ res.set_header("X-timings-voc", (json{{ "t_voc_ms", t_voc_ms }}).dump());
+ res.set_header("X-timings-spec", (json{{ "t_spec_ms", t_spec_ms }}).dump());
+
+ const auto chunked_content_provider = [audio = std::move(audio)](size_t, httplib::DataSink & sink) mutable {
+ // TODO: somehow reuse save_wav16 instead of duplicating the code here
+ const int n_sr = 24000; // sampling rate
+ // zero out first 0.25 seconds
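+ // 24000/4 = 6000 samples = 0.25 s at the 24 kHz sampling rate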
+ for (int i = 0; i < 24000/4; ++i) {
+ audio[i] = 0.0f;
+ }
+
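+ // assumption: wav_header (from tts-impl) defaults to mono 16-bit PCM, as in tts.cpp's save_wav16,
+ // so only the size-dependent fields need to be filled in here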
+ wav_header header;
+ header.sample_rate = n_sr;
+ header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8);
+ header.block_align = header.num_channels * (header.bits_per_sample / 8);
+ header.data_size = audio.size() * (header.bits_per_sample / 8);
+ header.chunk_size = 36 + header.data_size;
+
+ sink.write(reinterpret_cast<const char *>(&header), sizeof(header));
+
+ for (const auto & sample : audio) {
+ int16_t pcm_sample = static_cast<int16_t>(std::clamp(sample * 32767.0, -32768.0, 32767.0));
+ sink.write(reinterpret_cast<const char *>(&pcm_sample), sizeof(pcm_sample));
+ }
+ sink.done();
+ return false;
+ };
+
+ // https://mimetype.io/audio/vnd.wav
+ res.set_chunked_content_provider("audio/vnd.wav", chunked_content_provider);
+ res.status = 200;
+ };
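+ // example request (hypothetical host/port, for illustration only):
+ //   curl http://localhost:8080/v1/audio/speech \
+ //     -H "Content-Type: application/json" \
+ //     -d '{"input": "Hello world", "response_format": "wav"}' \
+ //     -o speech.wav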
+
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
json result = json::array();
const auto & loras = ctx_server.params_base.lora_adapters;
@@ -4166,6 +4374,7 @@ int main(int argc, char ** argv) {
svr->Post("/v1/reranking", handle_rerank);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);
+ svr->Post("/v1/audio/speech", handle_speech);
// LoRA adapters hotswap
svr->Get ("/lora-adapters", handle_lora_adapters_list);
svr->Post("/lora-adapters", handle_lora_adapters_apply);
diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt
index c72bd814c..7dcb43292 100644
--- a/examples/tts/CMakeLists.txt
+++ b/examples/tts/CMakeLists.txt
@@ -1,5 +1,5 @@
set(TARGET llama-tts)
-add_executable(${TARGET} tts.cpp)
+add_executable(${TARGET} tts-impl.hpp tts-impl.cpp tts.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/examples/tts/tts-impl.cpp b/examples/tts/tts-impl.cpp
new file mode 100644
index 000000000..49377a7bc
--- /dev/null
+++ b/examples/tts/tts-impl.cpp
@@ -0,0 +1,540 @@
+#include "log.h"
+#include "llama.h"
+#include "common.h"
+#include "tts-impl.hpp"
+
+#define _USE_MATH_DEFINES // For M_PI on MSVC
+
+#include
+#include
+#include
+#include