server: tests: assert embeddings are actually computed, make the embeddings endpoint configurable.

Add logs to investigate why the CI server test job is not starting
This commit is contained in:
Pierrick HYMBERT 2024-02-23 01:25:08 +01:00
parent cba6d4ea17
commit 1bd07e56c4
2 changed files with 28 additions and 5 deletions

View file

@ -8,6 +8,7 @@ Feature: llama.cpp server
And 42 as server seed And 42 as server seed
And 32 KV cache size And 32 KV cache size
And 1 slots And 1 slots
And embeddings extraction
And 32 server max tokens to predict And 32 server max tokens to predict
Then the server is starting Then the server is starting
Then the server is healthy Then the server is healthy

View file

@ -4,6 +4,7 @@ import os
import re import re
import socket import socket
import subprocess import subprocess
import time
from contextlib import closing from contextlib import closing
from re import RegexFlag from re import RegexFlag
@ -21,13 +22,14 @@ def step_server_config(context, server_fqdn, server_port):
context.base_url = f'http://{context.server_fqdn}:{context.server_port}' context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
context.server_continuous_batching = False
context.model_alias = None context.model_alias = None
context.n_ctx = None context.n_ctx = None
context.n_predict = None context.n_predict = None
context.n_server_predict = None context.n_server_predict = None
context.n_slots = None context.n_slots = None
context.server_api_key = None context.server_api_key = None
context.server_continuous_batching = False
context.server_embeddings = False
context.server_seed = None context.server_seed = None
context.user_api_key = None context.user_api_key = None
@ -70,15 +72,26 @@ def step_server_n_predict(context, n_predict):
def step_server_continuous_batching(context): def step_server_continuous_batching(context):
context.server_continuous_batching = True context.server_continuous_batching = True
@step(u'embeddings extraction')
def step_server_embeddings(context):
context.server_embeddings = True
@step(u"the server is starting") @step(u"the server is starting")
def step_start_server(context): def step_start_server(context):
start_server_background(context) start_server_background(context)
attempts = 0
while True: while True:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
result = sock.connect_ex((context.server_fqdn, context.server_port)) result = sock.connect_ex((context.server_fqdn, context.server_port))
if result == 0: if result == 0:
print("server started!")
return return
attempts += 1
if attempts > 20:
assert False, "server not started"
print("waiting for server to start...")
time.sleep(0.1)
@step(u"the server is {expecting_status}") @step(u"the server is {expecting_status}")
@ -301,6 +314,11 @@ def step_compute_embedding(context):
@step(u'embeddings are generated') @step(u'embeddings are generated')
def step_compute_embeddings(context): def step_compute_embeddings(context):
assert len(context.embeddings) > 0 assert len(context.embeddings) > 0
embeddings_computed = False
for emb in context.embeddings:
if emb != 0:
embeddings_computed = True
assert embeddings_computed, f"Embeddings: {context.embeddings}"
@step(u'an OAI compatible embeddings computation request for') @step(u'an OAI compatible embeddings computation request for')
@ -436,7 +454,8 @@ async def oai_chat_completions(user_prompt,
json=payload, json=payload,
headers=headers) as response: headers=headers) as response:
if enable_streaming: if enable_streaming:
print("payload", payload) # FIXME: does not work; the server is generating only one token
print("DEBUG payload", payload)
assert response.status == 200 assert response.status == 200
assert response.headers['Access-Control-Allow-Origin'] == origin assert response.headers['Access-Control-Allow-Origin'] == origin
assert response.headers['Content-Type'] == "text/event-stream" assert response.headers['Content-Type'] == "text/event-stream"
@ -453,7 +472,7 @@ async def oai_chat_completions(user_prompt,
if 'content' in delta: if 'content' in delta:
completion_response['content'] += delta['content'] completion_response['content'] += delta['content']
completion_response['timings']['predicted_n'] += 1 completion_response['timings']['predicted_n'] += 1
print(f"XXXXXXXXXXXXXXXXXcompletion_response: {completion_response}") print(f"DEBUG completion_response: {completion_response}")
else: else:
if expect_api_error is None or not expect_api_error: if expect_api_error is None or not expect_api_error:
assert response.status == 200 assert response.status == 200
@ -500,7 +519,7 @@ async def oai_chat_completions(user_prompt,
'predicted_n': chat_completion.usage.completion_tokens 'predicted_n': chat_completion.usage.completion_tokens
} }
} }
print("OAI response formatted to llama.cpp", completion_response) print("OAI response formatted to llama.cpp:", completion_response)
return completion_response return completion_response
@ -567,7 +586,7 @@ async def wait_for_health_status(context,
# Sometimes health requests are triggered after completions are predicted # Sometimes health requests are triggered after completions are predicted
if expected_http_status_code == 503: if expected_http_status_code == 503:
if len(context.completions) == 0: if len(context.completions) == 0:
print("\x1b[5;37;43mWARNING: forcing concurrents completions tasks," print("\x1b[33;42mWARNING: forcing concurrents completions tasks,"
" busy health check missed\x1b[0m") " busy health check missed\x1b[0m")
n_completions = await gather_concurrent_completions_tasks(context) n_completions = await gather_concurrent_completions_tasks(context)
if n_completions > 0: if n_completions > 0:
@ -604,6 +623,8 @@ def start_server_background(context):
] ]
if context.server_continuous_batching: if context.server_continuous_batching:
server_args.append('--cont-batching') server_args.append('--cont-batching')
if context.server_embeddings:
server_args.append('--embedding')
if context.model_alias is not None: if context.model_alias is not None:
server_args.extend(['--alias', context.model_alias]) server_args.extend(['--alias', context.model_alias])
if context.server_seed is not None: if context.server_seed is not None:
@ -620,3 +641,4 @@ def start_server_background(context):
context.server_process = subprocess.Popen( context.server_process = subprocess.Popen(
[str(arg) for arg in [context.server_path, *server_args]], [str(arg) for arg in [context.server_path, *server_args]],
close_fds=True) close_fds=True)
print(f"server pid={context.server_process.pid}")