From 9a566806f02ef3337759716f40430fb478478ddb Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 17 Dec 2024 21:36:50 +0100 Subject: [PATCH] fix test case --- examples/server/tests/unit/test_embedding.py | 30 +++++++++++--------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py index c66ee8a45..f1019a6b6 100644 --- a/examples/server/tests/unit/test_embedding.py +++ b/examples/server/tests/unit/test_embedding.py @@ -46,28 +46,32 @@ def test_embedding_multiple(): @pytest.mark.parametrize( - "content", + "content,is_multi_prompt", [ # single prompt - "string", - [12, 34, 56], - [12, 34, "string", 56, 78], + ("string", False), + ([12, 34, 56], False), + ([12, 34, "string", 56, 78], False), # multiple prompts - ["string1", "string2"], - ["string1", [12, 34, 56]], - [[12, 34, 56], [12, 34, 56]], - [[12, 34, 56], [12, "string", 34, 56]], + (["string1", "string2"], True), + (["string1", [12, 34, 56]], True), + ([[12, 34, 56], [12, 34, 56]], True), + ([[12, 34, 56], [12, "string", 34, 56]], True), ] ) -def test_embedding_mixed_input(content): +def test_embedding_mixed_input(content, is_multi_prompt: bool): global server server.start() res = server.make_request("POST", "/embeddings", data={"content": content}) assert res.status_code == 200 - assert len(res.body['data']) == len(content) - for d in res.body['data']: - assert 'embedding' in d - assert len(d['embedding']) > 1 + if is_multi_prompt: + assert len(res.body) == len(content) + for d in res.body: + assert 'embedding' in d + assert len(d['embedding']) > 1 + else: + assert 'embedding' in res.body + assert len(res.body['embedding']) > 1 def test_embedding_openai_library_single():