server : fill usage info in embeddings response
This commit is contained in:
parent
4f51968aca
commit
357a7bac41
3 changed files with 42 additions and 5 deletions
|
@ -719,14 +719,17 @@ struct server_task_result_embd : server_task_result {
|
|||
int index = 0;
|
||||
std::vector<float> embedding;
|
||||
|
||||
int32_t n_tokens;
|
||||
|
||||
virtual int get_index() override {
|
||||
return index;
|
||||
}
|
||||
|
||||
virtual json to_json() override {
|
||||
return json {
|
||||
{"index", index},
|
||||
{"embedding", embedding},
|
||||
{"index", index},
|
||||
{"embedding", embedding},
|
||||
{"tokens_evaluated", n_tokens},
|
||||
};
|
||||
}
|
||||
};
|
||||
|
@ -1995,6 +1998,7 @@ struct server_context {
|
|||
auto res = std::make_unique<server_task_result_embd>();
|
||||
res->id = slot.id_task;
|
||||
res->index = slot.index;
|
||||
res->n_tokens = slot.n_prompt_tokens;
|
||||
|
||||
const int n_embd = llama_n_embd(model);
|
||||
|
||||
|
|
|
@ -97,3 +97,33 @@ def test_same_prompt_give_same_result():
|
|||
vi = res.body['data'][i]['embedding']
|
||||
for x, y in zip(v0, vi):
|
||||
assert abs(x - y) < EPSILON
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"content,n_tokens",
|
||||
[
|
||||
("I believe the meaning of life is", 7),
|
||||
("This is a test", 4),
|
||||
]
|
||||
)
|
||||
def test_embedding_usage_single(content, n_tokens):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/embeddings", data={"input": content})
|
||||
assert res.status_code == 200
|
||||
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
||||
assert res.body['usage']['prompt_tokens'] == n_tokens
|
||||
|
||||
|
||||
def test_embedding_usage_multiple():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/embeddings", data={
|
||||
"input": [
|
||||
"I believe the meaning of life is",
|
||||
"I believe the meaning of life is",
|
||||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
|
||||
assert res.body['usage']['prompt_tokens'] == 2 * 7
|
||||
|
|
|
@ -560,6 +560,7 @@ static json oaicompat_completion_params_parse(
|
|||
|
||||
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
|
||||
json data = json::array();
|
||||
int32_t n_tokens = 0;
|
||||
int i = 0;
|
||||
for (const auto & elem : embeddings) {
|
||||
data.push_back(json{
|
||||
|
@ -567,14 +568,16 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
|
|||
{"index", i++},
|
||||
{"object", "embedding"}
|
||||
});
|
||||
|
||||
n_tokens += json_value(elem, "tokens_evaluated", 0);
|
||||
}
|
||||
|
||||
json res = json {
|
||||
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||
{"object", "list"},
|
||||
{"usage", json { // TODO: fill
|
||||
{"prompt_tokens", 0},
|
||||
{"total_tokens", 0}
|
||||
{"usage", json {
|
||||
{"prompt_tokens", n_tokens},
|
||||
{"total_tokens", n_tokens}
|
||||
}},
|
||||
{"data", data}
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue