server : add bad input handling in embeddings
This commit is contained in:
parent
4f51968aca
commit
38725ef6da
2 changed files with 48 additions and 2 deletions
|
@ -3649,13 +3649,18 @@ int main(int argc, char ** argv) {
|
|||
oaicompat = true;
|
||||
prompt = body.at("input");
|
||||
} else if (body.count("content") != 0) {
|
||||
// with "content", we only support single prompt
|
||||
prompt = std::vector<std::string>{body.at("content")};
|
||||
prompt = body.at("content");
|
||||
} else {
|
||||
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
// with "content", we only support single prompt
|
||||
if (!oaicompat && prompt.type() != json::value_t::string) {
|
||||
res_error(res, format_error_response("\"content\" must be a string", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
bool error = false;
|
||||
|
@ -3663,6 +3668,11 @@ int main(int argc, char ** argv) {
|
|||
std::vector<server_task> tasks;
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
if (tokenized_prompts[i].size() == 0) {
|
||||
res_error(res, format_error_response("input cannot be an empty string", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.index = i;
|
||||
|
|
|
@ -97,3 +97,39 @@ def test_same_prompt_give_same_result():
|
|||
vi = res.body['data'][i]['embedding']
|
||||
for x, y in zip(v0, vi):
|
||||
assert abs(x - y) < EPSILON
|
||||
|
||||
|
||||
# Negative test for the "input" field of POST /embeddings.
# Each parametrized value is sent as {"input": text} and the response must carry an error status (>= 400).
@pytest.mark.parametrize("text", [
|
||||
None,
|
||||
True,
|
||||
"",
|
||||
42,
|
||||
4.2,
|
||||
{},
|
||||
[],
|
||||
[""],
|
||||
# a list with one valid and one empty element is still expected to fail
|
||||
["This is a test", ""],
|
||||
])
|
||||
def test_embedding_bad_input(text):
|
||||
# `server` is a module-level object defined outside this excerpt
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/embeddings", data={"input": text})
|
||||
# only "some error status" is checked; the exact code is not pinned
assert res.status_code >= 400
|
||||
|
||||
|
||||
# Negative test for the "content" field of POST /embeddings.
# Each parametrized value is sent as {"content": text} and the response must carry an error status (>= 400).
@pytest.mark.parametrize("text", [
|
||||
None,
|
||||
True,
|
||||
"",
|
||||
42,
|
||||
4.2,
|
||||
{},
|
||||
[],
|
||||
[""],
|
||||
# a list holding a non-empty string is also expected to fail for "content"
|
||||
["This is a test"],
|
||||
])
|
||||
def test_embedding_content_bad_input(text):
|
||||
# `server` is a module-level object defined outside this excerpt
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/embeddings", data={"content": text})
|
||||
# only "some error status" is checked; the exact code is not pinned
assert res.status_code >= 400
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue