tests : update server tests

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-12-18 11:33:46 +02:00
parent 87df60166d
commit 2a5510ed82
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 16 additions and 15 deletions

View file

@ -746,7 +746,6 @@ struct server_task_result_embd : server_task_result {
return json { return json {
{"index", index}, {"index", index},
{"embedding", embedding}, {"embedding", embedding},
{"tokens_evaluated", n_tokens},
}; };
} }
@ -754,6 +753,7 @@ struct server_task_result_embd : server_task_result {
return json { return json {
{"index", index}, {"index", index},
{"embedding", embedding[0]}, {"embedding", embedding[0]},
{"tokens_evaluated", n_tokens},
}; };
} }
}; };

View file

@ -48,7 +48,7 @@ def test_embedding_multiple():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"content,is_multi_prompt", "input,is_multi_prompt",
[ [
# single prompt # single prompt
("string", False), ("string", False),
@ -61,19 +61,20 @@ def test_embedding_multiple():
([[12, 34, 56], [12, "string", 34, 56]], True), ([[12, 34, 56], [12, "string", 34, 56]], True),
] ]
) )
def test_embedding_mixed_input(content, is_multi_prompt: bool): def test_embedding_mixed_input(input, is_multi_prompt: bool):
global server global server
server.start() server.start()
res = server.make_request("POST", "/embeddings", data={"content": content}) res = server.make_request("POST", "/v1/embeddings", data={"input": input})
assert res.status_code == 200 assert res.status_code == 200
data = res.body['data']
if is_multi_prompt: if is_multi_prompt:
assert len(res.body) == len(content) assert len(data) == len(input)
for d in res.body: for d in data:
assert 'embedding' in d assert 'embedding' in d
assert len(d['embedding']) > 1 assert len(d['embedding']) > 1
else: else:
assert 'embedding' in res.body assert 'embedding' in data[0]
assert len(res.body['embedding']) > 1 assert len(data[0]['embedding']) > 1
def test_embedding_pooling_none(): def test_embedding_pooling_none():
@ -85,7 +86,7 @@ def test_embedding_pooling_none():
}) })
assert res.status_code == 200 assert res.status_code == 200
assert 'embedding' in res.body[0] assert 'embedding' in res.body[0]
assert len(res.body[0]['embedding']) == 3 assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
# make sure embedding vector is not normalized # make sure embedding vector is not normalized
for x in res.body[0]['embedding']: for x in res.body[0]['embedding']:
@ -172,7 +173,7 @@ def test_same_prompt_give_same_result():
def test_embedding_usage_single(content, n_tokens): def test_embedding_usage_single(content, n_tokens):
global server global server
server.start() server.start()
res = server.make_request("POST", "/embeddings", data={"input": content}) res = server.make_request("POST", "/v1/embeddings", data={"input": content})
assert res.status_code == 200 assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens assert res.body['usage']['prompt_tokens'] == n_tokens
@ -181,7 +182,7 @@ def test_embedding_usage_single(content, n_tokens):
def test_embedding_usage_multiple(): def test_embedding_usage_multiple():
global server global server
server.start() server.start()
res = server.make_request("POST", "/embeddings", data={ res = server.make_request("POST", "/v1/embeddings", data={
"input": [ "input": [
"I believe the meaning of life is", "I believe the meaning of life is",
"I believe the meaning of life is", "I believe the meaning of life is",