diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 8feff6702..94e06b3cc 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2003,6 +2003,10 @@ struct server_context {
         int32_t n_batch = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
+        // track if this is an embedding or non-embedding batch
+        // -1: none, 0: non-embedding, 1: embedding
+        int32_t batch_type = -1;
+
         // next, batch any pending prompts without exceeding n_batch
         if (params.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
@@ -2173,6 +2177,14 @@ struct server_context {
                     }
                 }
 
+                // check that we are in the right batch_type, if not defer the slot
+                bool slot_type = slot.embedding ? 1 : 0;
+                if (batch_type == -1) {
+                    batch_type = slot_type;
+                } else if (batch_type != slot_type) {
+                    continue;
+                }
+
                 // keep only the common part
                 int p0 = (int) system_tokens.size() + slot.n_past;
                 if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
@@ -2274,6 +2286,9 @@ struct server_context {
             {"n_tokens", batch.n_tokens},
         });
 
+        // make sure we're in the right embedding mode
+        llama_set_embeddings(ctx, batch_type == 1);
+
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
@@ -2988,6 +3003,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         json data = json::parse(req.body);
@@ -3083,6 +3103,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
@@ -3155,6 +3180,11 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
         json data = json::parse(req.body);
@@ -3243,11 +3273,6 @@ int main(int argc, char ** argv) {
     const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!params.embedding) {
-            res.status = 501;
-            res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
-            return;
-        }
 
         const json body = json::parse(req.body);
         bool is_openai = false;
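
For reference, a minimal standalone sketch of the `llama_set_embeddings` toggle this patch builds on: the same `llama_context` can be switched between embedding and generation mode at runtime, which is what the server now does once per batch via `batch_type`. This is illustrative only, assuming the llama.cpp C API of this era; the model path is a placeholder and the decode steps are elided.

// sketch.cpp -- illustrative only, not part of the patch;
// "model.gguf" is a placeholder path
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model * model = llama_load_model_from_file("model.gguf", llama_model_default_params());

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings = true; // start the context in embedding mode
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // ... decode an embedding batch, read vectors via llama_get_embeddings_seq() ...

    // flip the same context to generation mode, as the server now does
    // per batch depending on batch_type
    llama_set_embeddings(ctx, false);

    // ... decode a completion batch, read logits via llama_get_logits_ith() ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}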