diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index bc0d042ae..ce243680e 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -719,14 +719,17 @@ struct server_task_result_embd : server_task_result {
     int index = 0;
     std::vector<float> embedding;
 
+    int32_t n_tokens;
+
     virtual int get_index() override {
         return index;
     }
 
     virtual json to_json() override {
         return json {
-            {"index",     index},
-            {"embedding", embedding},
+            {"index",            index},
+            {"embedding",        embedding},
+            {"tokens_evaluated", n_tokens},
         };
     }
 };
@@ -1995,6 +1998,7 @@ struct server_context {
         auto res = std::make_unique<server_task_result_embd>();
         res->id    = slot.id_task;
         res->index = slot.index;
+        res->n_tokens = slot.n_prompt_tokens;
 
         const int n_embd = llama_n_embd(model);
 
diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py
index fc7c20064..fea1d6510 100644
--- a/examples/server/tests/unit/test_embedding.py
+++ b/examples/server/tests/unit/test_embedding.py
@@ -97,3 +97,33 @@ def test_same_prompt_give_same_result():
         vi = res.body['data'][i]['embedding']
         for x, y in zip(v0, vi):
             assert abs(x - y) < EPSILON
+
+
+@pytest.mark.parametrize(
+    "content,n_tokens",
+    [
+        ("I believe the meaning of life is", 7),
+        ("This is a test", 4),
+    ]
+)
+def test_embedding_usage_single(content, n_tokens):
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={"input": content})
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == n_tokens
+
+
+def test_embedding_usage_multiple():
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={
+        "input": [
+            "I believe the meaning of life is",
+            "I believe the meaning of life is",
+        ],
+    })
+    assert res.status_code == 200
+    assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
+    assert res.body['usage']['prompt_tokens'] == 2 * 7
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index c6f08bf21..9d281ee92 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -560,6 +560,7 @@ static json oaicompat_completion_params_parse(
 
 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
     json data = json::array();
+    int32_t n_tokens = 0;
     int i = 0;
     for (const auto & elem : embeddings) {
         data.push_back(json{
@@ -567,14 +568,16 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
             {"index",     i++},
             {"object",    "embedding"}
         });
+
+        n_tokens += json_value(elem, "tokens_evaluated", 0);
     }
 
     json res = json {
        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
        {"object", "list"},
-       {"usage", json { // TODO: fill
-           {"prompt_tokens", 0},
-           {"total_tokens", 0}
+       {"usage", json {
+           {"prompt_tokens", n_tokens},
+           {"total_tokens", n_tokens}
        }},
        {"data", data}
    };
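
For reference, a minimal sketch (not part of the patch) of how the new usage counters can be read from a running llama-server. It mirrors the request shape used in the tests above; the host/port are assumptions (server listening on http://localhost:8080 with an embedding model loaded), and the exact token counts depend on the model's tokenizer.

import requests

# Send the same kind of payload the new tests use and inspect the
# OpenAI-compatible response, whose "usage" block is now filled in.
resp = requests.post(
    "http://localhost:8080/embeddings",
    json={"input": "I believe the meaning of life is"},
)
resp.raise_for_status()
body = resp.json()

print(len(body["data"][0]["embedding"]))  # embedding dimension
print(body["usage"])                       # e.g. {"prompt_tokens": 7, "total_tokens": 7}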