From 38725ef6dafd46dd83efd7c141514fd43631e683 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krystian=20Chachu=C5=82a?= Date: Tue, 17 Dec 2024 13:04:02 +0100 Subject: [PATCH] server : add bad input handling in embeddings --- examples/server/server.cpp | 14 ++++++-- examples/server/tests/unit/test_embedding.py | 36 ++++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index bc0d042ae..db1c87ad8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3649,13 +3649,18 @@ int main(int argc, char ** argv) { oaicompat = true; prompt = body.at("input"); } else if (body.count("content") != 0) { - // with "content", we only support single prompt - prompt = std::vector{body.at("content")}; + prompt = body.at("content"); } else { res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); return; } + // with "content", we only support single prompt + if (!oaicompat && prompt.type() != json::value_t::string) { + res_error(res, format_error_response("\"content\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return; + } + // create and queue the task json responses = json::array(); bool error = false; @@ -3663,6 +3668,11 @@ int main(int argc, char ** argv) { std::vector tasks; std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true); for (size_t i = 0; i < tokenized_prompts.size(); i++) { + if (tokenized_prompts[i].size() == 0) { + res_error(res, format_error_response("input cannot be an empty string", ERROR_TYPE_INVALID_REQUEST)); + return; + } + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); task.id = ctx_server.queue_tasks.get_new_id(); task.index = i; diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py index fc7c20064..7ee4a74a2 100644 --- a/examples/server/tests/unit/test_embedding.py +++ b/examples/server/tests/unit/test_embedding.py @@ -97,3 +97,39 @@ def test_same_prompt_give_same_result(): vi = res.body['data'][i]['embedding'] for x, y in zip(v0, vi): assert abs(x - y) < EPSILON + + +@pytest.mark.parametrize("text", [ + None, + True, + "", + 42, + 4.2, + {}, + [], + [""], + ["This is a test", ""], +]) +def test_embedding_bad_input(text): + global server + server.start() + res = server.make_request("POST", "/embeddings", data={"input": text}) + assert res.status_code >= 400 + + +@pytest.mark.parametrize("text", [ + None, + True, + "", + 42, + 4.2, + {}, + [], + [""], + ["This is a test"], +]) +def test_embedding_content_bad_input(text): + global server + server.start() + res = server.make_request("POST", "/embeddings", data={"content": text}) + assert res.status_code >= 400