server: tests: assert embeddings are actually computed, make the embeddings endpoint configurable.

Add logs to investigate why the CI server test job is not starting
Pierrick HYMBERT 2024-02-23 01:25:08 +01:00
parent cba6d4ea17
commit 1bd07e56c4
2 changed files with 28 additions and 5 deletions
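The new assertion only passes if the returned vector contains at least one non-zero component. A minimal manual probe of the behavior under test, assuming a server started with --embedding and listening on localhost:8080 (endpoint path, host, and port are assumptions for illustration, not taken from this commit):

    import requests  # assumes the requests package is installed

    # Hypothetical probe of the embeddings endpoint the tests assert on.
    response = requests.post('http://localhost:8080/embedding',
                             json={'content': 'What is the capital of France?'})
    embedding = response.json()['embedding']
    # An all-zero vector would mean no embedding was actually computed.
    assert any(v != 0 for v in embedding), 'embeddings were not computed'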


@@ -8,6 +8,7 @@ Feature: llama.cpp server
     And 42 as server seed
     And 32 KV cache size
     And 1 slots
+    And embeddings extraction
     And 32 server max tokens to predict
     Then the server is starting
     Then the server is healthy
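These scenario steps are executed by behave. A sketch of running just this feature programmatically (the feature-file path is an assumption, not taken from this commit):

    # Sketch: invoke behave's Python entry point on the feature file.
    from behave.__main__ import main as behave_main

    behave_main(['features/server.feature', '--summary'])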


@@ -4,6 +4,7 @@ import os
 import re
 import socket
 import subprocess
+import time
 from contextlib import closing
 from re import RegexFlag
@@ -21,13 +22,14 @@ def step_server_config(context, server_fqdn, server_port):
     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
-    context.server_continuous_batching = False
     context.model_alias = None
     context.n_ctx = None
     context.n_predict = None
     context.n_server_predict = None
     context.n_slots = None
     context.server_api_key = None
+    context.server_continuous_batching = False
+    context.server_embeddings = False
     context.server_seed = None
     context.user_api_key = None
@@ -70,15 +72,26 @@ def step_server_n_predict(context, n_predict):
 def step_server_continuous_batching(context):
     context.server_continuous_batching = True
 
+@step(u'embeddings extraction')
+def step_server_embeddings(context):
+    context.server_embeddings = True
+
 @step(u"the server is starting")
 def step_start_server(context):
     start_server_background(context)
+    attempts = 0
     while True:
         with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
             result = sock.connect_ex((context.server_fqdn, context.server_port))
             if result == 0:
+                print("server started!")
                 return
+            attempts += 1
+            if attempts > 20:
+                assert False, "server not started"
+            print("waiting for server to start...")
+            time.sleep(0.1)
 
 @step(u"the server is {expecting_status}")
@@ -301,6 +314,11 @@ def step_compute_embedding(context):
 @step(u'embeddings are generated')
 def step_compute_embeddings(context):
     assert len(context.embeddings) > 0
+    embeddings_computed = False
+    for emb in context.embeddings:
+        if emb != 0:
+            embeddings_computed = True
+    assert embeddings_computed, f"Embeddings: {context.embeddings}"
 
 @step(u'an OAI compatible embeddings computation request for')
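The flag-based loop above marks the embeddings as computed as soon as any component is non-zero, so it is equivalent to a single any() check; a compact restatement (not part of the commit):

    # Equivalent check: an all-zero vector means nothing was computed.
    assert any(emb != 0 for emb in context.embeddings), \
        f"Embeddings: {context.embeddings}"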
@@ -436,7 +454,8 @@ async def oai_chat_completions(user_prompt,
                         json=payload,
                         headers=headers) as response:
            if enable_streaming:
-                print("payload", payload)
+                # FIXME: does not work; the server is generating only one token
+                print("DEBUG payload", payload)
                assert response.status == 200
                assert response.headers['Access-Control-Allow-Origin'] == origin
                assert response.headers['Content-Type'] == "text/event-stream"
@@ -453,7 +472,7 @@ async def oai_chat_completions(user_prompt,
                    if 'content' in delta:
                        completion_response['content'] += delta['content']
                        completion_response['timings']['predicted_n'] += 1
-                print(f"XXXXXXXXXXXXXXXXXcompletion_response: {completion_response}")
+                print(f"DEBUG completion_response: {completion_response}")
            else:
                if expect_api_error is None or not expect_api_error:
                    assert response.status == 200
@@ -500,7 +519,7 @@ async def oai_chat_completions(user_prompt,
                'predicted_n': chat_completion.usage.completion_tokens
            }
        }
-        print("OAI response formatted to llama.cpp", completion_response)
+        print("OAI response formatted to llama.cpp:", completion_response)
    return completion_response
@@ -567,7 +586,7 @@ async def wait_for_health_status(context,
            # Sometimes health requests are triggered after completions are predicted
            if expected_http_status_code == 503:
                if len(context.completions) == 0:
-                    print("\x1b[5;37;43mWARNING: forcing concurrents completions tasks,"
+                    print("\x1b[33;42mWARNING: forcing concurrents completions tasks,"
                          " busy health check missed\x1b[0m")
                    n_completions = await gather_concurrent_completions_tasks(context)
                    if n_completions > 0:
@@ -604,6 +623,8 @@ def start_server_background(context):
     ]
     if context.server_continuous_batching:
         server_args.append('--cont-batching')
+    if context.server_embeddings:
+        server_args.append('--embedding')
     if context.model_alias is not None:
         server_args.extend(['--alias', context.model_alias])
     if context.server_seed is not None:
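With 'embeddings extraction' declared in the scenario, start_server_background now appends --embedding to the argument list. A sketch of the resulting argv construction (host, port, and model path are placeholder values, not from this commit):

    # Illustrative argv mirroring start_server_background's flag handling.
    server_args = ['--host', '127.0.0.1', '--port', '8080',
                   '--model', 'model.gguf']
    server_embeddings = True  # set by the 'embeddings extraction' step
    if server_embeddings:
        server_args.append('--embedding')
    # server --host 127.0.0.1 --port 8080 --model model.gguf --embedding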
@@ -620,3 +641,4 @@ def start_server_background(context):
     context.server_process = subprocess.Popen(
         [str(arg) for arg in [context.server_path, *server_args]],
         close_fds=True)
+    print(f"server pid={context.server_process.pid}")