server : update /embeddings and /v1/embeddings endpoints

ggml-ci

commit abf33e2017
parent 2a94c33028

2 changed files with 59 additions and 28 deletions
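The observable change: the native /embeddings endpoint now always returns a plain JSON array (one object per input, each with "index", "embedding" and "tokens_evaluated"), while /v1/embeddings becomes a dedicated OAI-compatible route that wraps results in the OpenAI-style envelope. A minimal sketch of the two response shapes, assuming a local server on localhost:8080 (the host/port are assumptions, not part of the commit):

import requests

base = "http://localhost:8080"

# native endpoint: plain JSON array, one entry per input prompt
native = requests.post(f"{base}/embeddings", json={"input": "hello"}).json()
print(sorted(native[0].keys()))  # expected: ['embedding', 'index', 'tokens_evaluated']

# OAI-compatible endpoint: OpenAI-style envelope with a "data" list
oai = requests.post(f"{base}/v1/embeddings", json={"input": "hello"}).json()
print(len(oai["data"]))          # expected: 1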
@@ -731,25 +731,31 @@ struct server_task_result_embd : server_task_result {
 
     int32_t n_tokens;
 
+    // OAI-compat fields
+    bool oaicompat = false;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
-        if (embedding.size() == 1) {
-            // to be OAI compatible
-            return json {
-                {"index", index},
-                {"embedding", embedding[0]},
-            };
-        }
+        return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
+    }
+
+    json to_json_non_oaicompat() {
         return json {
             {"index", index},
             {"embedding", embedding},
             {"tokens_evaluated", n_tokens},
         };
     }
+
+    json to_json_oaicompat() {
+        return json {
+            {"index", index},
+            {"embedding", embedding[0]},
+        };
+    }
 };
 
 struct server_task_result_rerank : server_task_result {
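Note on the split above: "embedding" holds an array of vectors (a single element when pooling is enabled, one vector per token with pooling 'none'), so the OAI variant flattens it with embedding[0]. A toy illustration of that flattening in Python (shapes only; illustrative, not from the commit):

# pooled result: one vector stored inside a one-element array
embedding = [[0.1, 0.2, 0.3]]

non_oai = {"index": 0, "embedding": embedding, "tokens_evaluated": 5}
oai     = {"index": 0, "embedding": embedding[0]}

assert oai["embedding"] == [0.1, 0.2, 0.3]  # the array wrapper is stripped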
@@ -2027,9 +2033,10 @@ struct server_context {
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
         res->id = slot.id_task;
         res->index = slot.index;
         res->n_tokens = slot.n_prompt_tokens;
+        res->oaicompat = slot.params.oaicompat;
 
         const int n_embd = llama_n_embd(model);
 
@@ -3678,14 +3685,17 @@ int main(int argc, char ** argv) {
         res_ok(res, data);
     };
 
-    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
         const json body = json::parse(req.body);
-        bool oaicompat = false;
+
+        if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
 
         // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
-        if (body.contains("input")) {
-            oaicompat = true;
+        if (body.count("input") != 0) {
             prompt = body.at("input");
         } else if (body.contains("content")) {
             oaicompat = false;
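The new guard above means a server running with pooling type 'none' rejects OAI-compatible requests up front instead of producing a response the OpenAI format cannot express. A sketch of probing this behavior, assuming a local server started with --pooling none (address and invocation are assumptions):

import requests

res = requests.post("http://localhost:8080/v1/embeddings",
                    json={"input": "hello hello hello"})
# "Pooling type 'none' is not OAI compatible. Please use a different pooling type"
assert res.status_code == 400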
@@ -3710,10 +3720,15 @@ int main(int argc, char ** argv) {
         {
             std::vector<server_task> tasks;
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
                 task.id = ctx_server.queue_tasks.get_new_id();
                 task.index = i;
                 task.prompt_tokens = std::move(tokenized_prompts[i]);
+
+                // OAI-compat
+                task.params.oaicompat = oaicompat;
+
                 tasks.push_back(task);
             }
 
@@ -3741,12 +3756,18 @@ int main(int argc, char ** argv) {
         }
 
         // write JSON response
-        json root = oaicompat
-            ? format_embeddings_response_oaicompat(body, responses)
-            : responses.size() == 1 ? responses[0] : json(responses);
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
         res_ok(res, root);
     };
 
+    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, false);
+    };
+
+    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, true);
+    };
+
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
         if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
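The two wrappers above thread a single boolean through one shared implementation rather than duplicating the handler. The same shape in Python, as an illustration of the pattern only (names mirror the C++ but nothing here is from the commit):

from functools import partial

def handle_embeddings_impl(req, res, oaicompat):
    # shared parsing, validation and task dispatch would live here
    ...

handle_embeddings     = partial(handle_embeddings_impl, oaicompat=False)
handle_embeddings_oai = partial(handle_embeddings_impl, oaicompat=True)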
@@ -3920,7 +3941,7 @@ int main(int argc, char ** argv) {
     svr->Post("/infill", handle_infill);
     svr->Post("/embedding", handle_embeddings); // legacy
     svr->Post("/embeddings", handle_embeddings);
-    svr->Post("/v1/embeddings", handle_embeddings);
+    svr->Post("/v1/embeddings", handle_embeddings_oai);
     svr->Post("/rerank", handle_rerank);
     svr->Post("/reranking", handle_rerank);
     svr->Post("/v1/rerank", handle_rerank);
@@ -16,7 +16,7 @@ def test_embedding_single():
     global server
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": "I believe the meaning of life is",
     })
     assert res.status_code == 200
@@ -32,7 +32,7 @@ def test_embedding_multiple():
     global server
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "Write a joke about AI from a very long prompt which will not be truncated",
@@ -84,16 +84,26 @@ def test_embedding_pooling_none():
         "input": "hello hello hello",
     })
     assert res.status_code == 200
-    assert len(res.body['data']) == 1
-    assert 'embedding' in res.body['data'][0]
-    assert len(res.body['data'][0]['embedding']) == 3
+    assert 'embedding' in res.body[0]
+    assert len(res.body[0]['embedding']) == 3
+
+
+def test_embedding_pooling_none_oai():
+    global server
+    server.pooling = 'none'
+    server.start()
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": "hello hello hello",
+    })
+    # /v1/embeddings does not support pooling type 'none'
+    assert res.status_code == 400
 
 
 def test_embedding_openai_library_single():
     global server
     server.pooling = 'last'
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
     assert len(res.data) == 1
    assert len(res.data[0].embedding) > 1
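With pooling 'none' the native endpoint returns one vector per input token, which is what the updated assertions check: three "hello" tokens yield three vectors. A sketch of the same check outside the test harness, assuming a local server started with --pooling none on localhost:8080:

import requests

res = requests.post("http://localhost:8080/embeddings",
                    json={"input": "hello hello hello"})
body = res.json()
assert len(body[0]["embedding"]) == 3  # one embedding per token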
@@ -103,7 +113,7 @@ def test_embedding_openai_library_multiple():
     global server
     server.pooling = 'last'
     server.start()
-    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
     res = client.embeddings.create(model="text-embedding-3-small", input=[
         "I believe the meaning of life is",
         "Write a joke about AI from a very long prompt which will not be truncated",
@@ -119,7 +129,7 @@ def test_embedding_error_prompt_too_long():
     global server
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": "This is a test " * 512,
     })
     assert res.status_code != 200
@@ -129,7 +139,7 @@ def test_embedding_error_prompt_too_long():
 def test_same_prompt_give_same_result():
     server.pooling = 'last'
     server.start()
-    res = server.make_request("POST", "/embeddings", data={
+    res = server.make_request("POST", "/v1/embeddings", data={
         "input": [
             "I believe the meaning of life is",
             "I believe the meaning of life is",
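Client-side, the only required change is pointing the OpenAI SDK at the /v1 prefix, exactly as the updated tests do. A minimal usage sketch (host and port assumed; the model string is passed through as in the tests):

from openai import OpenAI

client = OpenAI(api_key="dummy", base_url="http://localhost:8080/v1")
res = client.embeddings.create(model="text-embedding-3-small",
                               input="I believe the meaning of life is")
print(len(res.data[0].embedding))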