diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 436170a03..e6fcf49b5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -3651,14 +3651,14 @@ int main(int argc, char ** argv) {
         const json body = json::parse(req.body);
         bool oaicompat = false;
 
-        // an input prompt can be a string or a list of tokens (integer)
+        // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
-        if (body.count("input") != 0) {
+        if (body.contains("input")) {
             oaicompat = true;
             prompt = body.at("input");
-        } else if (body.count("content") != 0) {
-            // with "content", we only support single prompt
-            prompt = std::vector{body.at("content")};
+        } else if (body.contains("content")) {
+            oaicompat = false;
+            prompt = body.at("content");
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py
index fea1d6510..c66ee8a45 100644
--- a/examples/server/tests/unit/test_embedding.py
+++ b/examples/server/tests/unit/test_embedding.py
@@ -45,6 +45,31 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1
 
 
+@pytest.mark.parametrize(
+    "content",
+    [
+        # single prompt
+        "string",
+        [12, 34, 56],
+        [12, 34, "string", 56, 78],
+        # multiple prompts
+        ["string1", "string2"],
+        ["string1", [12, 34, 56]],
+        [[12, 34, 56], [12, 34, 56]],
+        [[12, 34, 56], [12, "string", 34, 56]],
+    ]
+)
+def test_embedding_mixed_input(content):
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={"content": content})
+    assert res.status_code == 200
+    assert len(res.body['data']) == len(content)
+    for d in res.body['data']:
+        assert 'embedding' in d
+        assert len(d['embedding']) > 1
+
+
 def test_embedding_openai_library_single():
     global server
     server.start()
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 8fffe484a..ffdffe904 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -138,6 +138,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
  * and multiple prompts (multi-tasks):
  * - "prompt": ["string1", "string2"]
  * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, 56], [78, 90, 12]]
  * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
  */
 static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
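
The net effect of this patch is that the raw "content" field of /embeddings now accepts every prompt shape that tokenize_input_prompts() understands, not just a single prompt. Below is a minimal client-side sketch of what that allows, mirroring how the new pytest case reads the response (one entry in "data" per prompt). It is not part of the patch: the server address (the default llama-server port 8080 is assumed), the use of the `requests` package, and the response-shape reading are illustrative assumptions.

```python
# Minimal usage sketch (not part of the patch): POST mixed-shape "content"
# to the non-OpenAI /embeddings route, the same shapes the new test exercises.
# Assumes a llama-server instance with an embedding model is listening on
# http://localhost:8080 and that the `requests` package is installed.
import requests

SERVER_URL = "http://localhost:8080"  # assumed default llama-server address

# the same prompt shapes documented for tokenize_input_prompts()
payloads = [
    ["string1", "string2"],          # multiple string prompts
    ["string1", [12, 34, 56]],       # mixed: string prompt + token-id prompt
    [[12, 34, 56], [78, 90, 12]],    # multiple token-id prompts
]

for content in payloads:
    res = requests.post(f"{SERVER_URL}/embeddings", json={"content": content})
    res.raise_for_status()
    # read the response the way the added test does: one "data" entry per prompt
    for d in res.json()["data"]:
        print(len(d["embedding"]))   # embedding vector length for each prompt
```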