server : (embeddings) using same format for "input" and "content" (#10872)

* server : (embeddings) using same format for "input" and "content"

* fix test case

* handle empty input case

* fix test
This commit is contained in:
Xuan Son Nguyen 2024-12-18 09:55:09 +01:00 committed by GitHub
parent 6b064c92b4
commit 46828872c3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 47 additions and 9 deletions

View file

@ -45,6 +45,35 @@ def test_embedding_multiple():
assert len(d['embedding']) > 1
@pytest.mark.parametrize(
"content,is_multi_prompt",
[
# single prompt
("string", False),
([12, 34, 56], False),
([12, 34, "string", 56, 78], False),
# multiple prompts
(["string1", "string2"], True),
(["string1", [12, 34, 56]], True),
([[12, 34, 56], [12, 34, 56]], True),
([[12, 34, 56], [12, "string", 34, 56]], True),
]
)
def test_embedding_mixed_input(content, is_multi_prompt: bool):
global server
server.start()
res = server.make_request("POST", "/embeddings", data={"content": content})
assert res.status_code == 200
if is_multi_prompt:
assert len(res.body) == len(content)
for d in res.body:
assert 'embedding' in d
assert len(d['embedding']) > 1
else:
assert 'embedding' in res.body
assert len(res.body['embedding']) > 1
def test_embedding_openai_library_single():
global server
server.start()
@ -102,8 +131,8 @@ def test_same_prompt_give_same_result():
@pytest.mark.parametrize(
"content,n_tokens",
[
("I believe the meaning of life is", 7),
("This is a test", 4),
("I believe the meaning of life is", 9),
("This is a test", 6),
]
)
def test_embedding_usage_single(content, n_tokens):
@ -126,4 +155,4 @@ def test_embedding_usage_multiple():
})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == 2 * 7
assert res.body['usage']['prompt_tokens'] == 2 * 9