From 88cc9719c449a8b6ef7669f7be6eebd8eb5cafa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krystian=20Chachu=C5=82a?= Date: Mon, 16 Dec 2024 14:45:06 +0100 Subject: [PATCH] server : fill usage info in reranking response --- examples/server/server.cpp | 8 ++++++-- examples/server/tests/unit/test_rerank.py | 23 +++++++++++++++++++++++ examples/server/utils.hpp | 9 ++++++--- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ce243680e..436170a03 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -738,14 +738,17 @@ struct server_task_result_rerank : server_task_result { int index = 0; float score = -1e6; + int32_t n_tokens; + virtual int get_index() override { return index; } virtual json to_json() override { return json { - {"index", index}, - {"score", score}, + {"index", index}, + {"score", score}, + {"tokens_evaluated", n_tokens}, }; } }; @@ -2034,6 +2037,7 @@ struct server_context { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; + res->n_tokens = slot.n_prompt_tokens; for (int i = 0; i < batch.n_tokens; ++i) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { diff --git a/examples/server/tests/unit/test_rerank.py b/examples/server/tests/unit/test_rerank.py index 189bc4c96..7203d7943 100644 --- a/examples/server/tests/unit/test_rerank.py +++ b/examples/server/tests/unit/test_rerank.py @@ -53,3 +53,26 @@ def test_invalid_rerank_req(documents): }) assert res.status_code == 400 assert "error" in res.body + + +@pytest.mark.parametrize( + "query,doc1,doc2,n_tokens", + [ + ("Machine learning is", "A machine", "Learning is", 19), + ("Which city?", "Machine learning is ", "Paris, capitale de la", 26), + ] +) +def test_rerank_usage(query, doc1, doc2, n_tokens): + global server + server.start() + + res = server.make_request("POST", "/rerank", data={ + "query": query, + "documents": [ + doc1, + doc2, + ] + }) + assert res.status_code == 200 + assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] + assert res.body['usage']['prompt_tokens'] == n_tokens diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 9d281ee92..8fffe484a 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -587,20 +587,23 @@ static json format_embeddings_response_oaicompat(const json & request, const jso static json format_response_rerank(const json & request, const json & ranks) { json data = json::array(); + int32_t n_tokens = 0; int i = 0; for (const auto & rank : ranks) { data.push_back(json{ {"index", i++}, {"relevance_score", json_value(rank, "score", 0.0)}, }); + + n_tokens += json_value(rank, "tokens_evaluated", 0); } json res = json { {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, {"object", "list"}, - {"usage", json { // TODO: fill - {"prompt_tokens", 0}, - {"total_tokens", 0} + {"usage", json { + {"prompt_tokens", n_tokens}, + {"total_tokens", n_tokens} }}, {"results", data} };