fix test case

This commit is contained in:
Xuan Son Nguyen 2024-12-17 21:36:50 +01:00
parent d4e0bad0ae
commit 9a566806f0

View file

@ -46,28 +46,32 @@ def test_embedding_multiple():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"content", "content,is_multi_prompt",
[ [
# single prompt # single prompt
"string", ("string", False),
[12, 34, 56], ([12, 34, 56], False),
[12, 34, "string", 56, 78], ([12, 34, "string", 56, 78], False),
# multiple prompts # multiple prompts
["string1", "string2"], (["string1", "string2"], True),
["string1", [12, 34, 56]], (["string1", [12, 34, 56]], True),
[[12, 34, 56], [12, 34, 56]], ([[12, 34, 56], [12, 34, 56]], True),
[[12, 34, 56], [12, "string", 34, 56]], ([[12, 34, 56], [12, "string", 34, 56]], True),
] ]
) )
def test_embedding_mixed_input(content): def test_embedding_mixed_input(content, is_multi_prompt: bool):
global server global server
server.start() server.start()
res = server.make_request("POST", "/embeddings", data={"content": content}) res = server.make_request("POST", "/embeddings", data={"content": content})
assert res.status_code == 200 assert res.status_code == 200
assert len(res.body['data']) == len(content) if is_multi_prompt:
for d in res.body['data']: assert len(res.body) == len(content)
assert 'embedding' in d for d in res.body:
assert len(d['embedding']) > 1 assert 'embedding' in d
assert len(d['embedding']) > 1
else:
assert 'embedding' in res.body
assert len(res.body['embedding']) > 1
def test_embedding_openai_library_single(): def test_embedding_openai_library_single():