server : be explicit about the pooling type in the tests

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-12-17 11:45:18 +02:00
parent 2dea48758e
commit 2a94c33028
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 9 additions and 3 deletions

View file

@ -14,6 +14,7 @@ def create_server():
def test_embedding_single(): def test_embedding_single():
global server global server
server.pooling = 'last'
server.start() server.start()
res = server.make_request("POST", "/embeddings", data={ res = server.make_request("POST", "/embeddings", data={
"input": "I believe the meaning of life is", "input": "I believe the meaning of life is",
@ -29,6 +30,7 @@ def test_embedding_single():
def test_embedding_multiple(): def test_embedding_multiple():
global server global server
server.pooling = 'last'
server.start() server.start()
res = server.make_request("POST", "/embeddings", data={ res = server.make_request("POST", "/embeddings", data={
"input": [ "input": [
@ -75,7 +77,8 @@ def test_embedding_mixed_input(content, is_multi_prompt: bool):
def test_embedding_pooling_none(): def test_embedding_pooling_none():
server = ServerPreset.bert_bge_small(pooling = 'none') global server
server.pooling = 'none'
server.start() server.start()
res = server.make_request("POST", "/embeddings", data={ res = server.make_request("POST", "/embeddings", data={
"input": "hello hello hello", "input": "hello hello hello",
@ -88,6 +91,7 @@ def test_embedding_pooling_none():
def test_embedding_openai_library_single(): def test_embedding_openai_library_single():
global server global server
server.pooling = 'last'
server.start() server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is") res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
@ -97,6 +101,7 @@ def test_embedding_openai_library_single():
def test_embedding_openai_library_multiple(): def test_embedding_openai_library_multiple():
global server global server
server.pooling = 'last'
server.start() server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.embeddings.create(model="text-embedding-3-small", input=[ res = client.embeddings.create(model="text-embedding-3-small", input=[
@ -112,6 +117,7 @@ def test_embedding_openai_library_multiple():
def test_embedding_error_prompt_too_long(): def test_embedding_error_prompt_too_long():
global server global server
server.pooling = 'last'
server.start() server.start()
res = server.make_request("POST", "/embeddings", data={ res = server.make_request("POST", "/embeddings", data={
"input": "This is a test " * 512, "input": "This is a test " * 512,
@ -121,6 +127,7 @@ def test_embedding_error_prompt_too_long():
def test_same_prompt_give_same_result(): def test_same_prompt_give_same_result():
server.pooling = 'last'
server.start() server.start()
res = server.make_request("POST", "/embeddings", data={ res = server.make_request("POST", "/embeddings", data={
"input": [ "input": [

View file

@ -275,7 +275,7 @@ class ServerPreset:
return server return server
@staticmethod @staticmethod
def bert_bge_small(pooling = 'last') -> ServerProcess: def bert_bge_small() -> ServerProcess:
server = ServerProcess() server = ServerProcess()
server.model_hf_repo = "ggml-org/models" server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf" server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
@ -286,7 +286,6 @@ class ServerPreset:
server.n_slots = 2 server.n_slots = 2
server.seed = 42 server.seed = 42
server.server_embeddings = True server.server_embeddings = True
server.pooling = pooling
return server return server
@staticmethod @staticmethod