server : fill usage info in reranking response

This commit is contained in:
Krystian Chachuła 2024-12-16 14:45:06 +01:00 committed by Krystian Chachuła
parent 357a7bac41
commit 88cc9719c4
3 changed files with 35 additions and 5 deletions

View file

@ -738,14 +738,17 @@ struct server_task_result_rerank : server_task_result {
int index = 0; int index = 0;
float score = -1e6; float score = -1e6;
int32_t n_tokens;
virtual int get_index() override { virtual int get_index() override {
return index; return index;
} }
virtual json to_json() override { virtual json to_json() override {
return json { return json {
{"index", index}, {"index", index},
{"score", score}, {"score", score},
{"tokens_evaluated", n_tokens},
}; };
} }
}; };
@ -2034,6 +2037,7 @@ struct server_context {
auto res = std::make_unique<server_task_result_rerank>(); auto res = std::make_unique<server_task_result_rerank>();
res->id = slot.id_task; res->id = slot.id_task;
res->index = slot.index; res->index = slot.index;
res->n_tokens = slot.n_prompt_tokens;
for (int i = 0; i < batch.n_tokens; ++i) { for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {

View file

@ -53,3 +53,26 @@ def test_invalid_rerank_req(documents):
}) })
assert res.status_code == 400 assert res.status_code == 400
assert "error" in res.body assert "error" in res.body
@pytest.mark.parametrize(
"query,doc1,doc2,n_tokens",
[
("Machine learning is", "A machine", "Learning is", 19),
("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
]
)
def test_rerank_usage(query, doc1, doc2, n_tokens):
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": query,
"documents": [
doc1,
doc2,
]
})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens

View file

@ -587,20 +587,23 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
static json format_response_rerank(const json & request, const json & ranks) { static json format_response_rerank(const json & request, const json & ranks) {
json data = json::array(); json data = json::array();
int32_t n_tokens = 0;
int i = 0; int i = 0;
for (const auto & rank : ranks) { for (const auto & rank : ranks) {
data.push_back(json{ data.push_back(json{
{"index", i++}, {"index", i++},
{"relevance_score", json_value(rank, "score", 0.0)}, {"relevance_score", json_value(rank, "score", 0.0)},
}); });
n_tokens += json_value(rank, "tokens_evaluated", 0);
} }
json res = json { json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"}, {"object", "list"},
{"usage", json { // TODO: fill {"usage", json {
{"prompt_tokens", 0}, {"prompt_tokens", n_tokens},
{"total_tokens", 0} {"total_tokens", n_tokens}
}}, }},
{"results", data} {"results", data}
}; };