server: tests: assert embeddings are actually computed, make the embeddings endpoint configurable.
Add logs to investigate why the CI server test job is not starting.
parent cba6d4ea17, commit 1bd07e56c4
2 changed files with 28 additions and 5 deletions
@@ -8,6 +8,7 @@ Feature: llama.cpp server
     And 42 as server seed
     And 32 KV cache size
     And 1 slots
+    And embeddings extraction
     And 32 server max tokens to predict
     Then the server is starting
     Then the server is healthy
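The new Background line `And embeddings extraction` is matched against the step definitions added below. For context, behave's default "parse" matcher also handles the parametrized lines in this Background; a minimal sketch (step name and attribute are illustrative, not from this commit) of how a line like `And 32 KV cache size` resolves:

# Minimal sketch, not taken from this commit: behave's "parse" matcher
# turns {n_ctx:d} into an integer capture, so the Gherkin line
# "And 32 KV cache size" calls this function with n_ctx=32.
from behave import step

@step(u'{n_ctx:d} KV cache size')
def step_n_ctx(context, n_ctx):
    context.n_ctx = n_ctx  # read later when building the server arguments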
@@ -4,6 +4,7 @@ import os
 import re
+import socket
 import subprocess
 import time
 from contextlib import closing
 from re import RegexFlag
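The `socket`/`closing` pair is what the readiness poll further down relies on. As a standalone sketch (host and port are placeholders):

# Standalone sketch: connect_ex() returns an errno instead of raising,
# so 0 means the port accepted the connection; closing() guarantees the
# socket is released, although Python 3 sockets are context managers too.
import socket
from contextlib import closing

with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
    if sock.connect_ex(("127.0.0.1", 8080)) == 0:  # placeholder host/port
        print("port is open")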
@@ -21,13 +22,14 @@ def step_server_config(context, server_fqdn, server_port):
     context.base_url = f'http://{context.server_fqdn}:{context.server_port}'

-    context.server_continuous_batching = False
     context.model_alias = None
     context.n_ctx = None
     context.n_predict = None
     context.n_server_predict = None
     context.n_slots = None
     context.server_api_key = None
+    context.server_continuous_batching = False
+    context.server_embeddings = False
     context.server_seed = None
     context.user_api_key = None

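behave's `context` is an attribute bag that persists across the steps of a scenario, which is why every default is reset in the configuration step. A hypothetical alternative placement (not what this suite does) would be a `before_scenario` hook in environment.py:

# Hypothetical alternative, for illustration only: behave calls this hook
# from environment.py before every scenario, so defaults set here would
# not need to live inside a step definition.
def before_scenario(context, scenario):
    context.server_embeddings = False
    context.server_continuous_batching = False
    context.server_seed = None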
@@ -70,15 +72,26 @@ def step_server_n_predict(context, n_predict):
 def step_server_continuous_batching(context):
     context.server_continuous_batching = True


+@step(u'embeddings extraction')
+def step_server_embeddings(context):
+    context.server_embeddings = True
+
+
 @step(u"the server is starting")
 def step_start_server(context):
     start_server_background(context)
+    attempts = 0
     while True:
         with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
             result = sock.connect_ex((context.server_fqdn, context.server_port))
             if result == 0:
+                print("server started!")
                 return
+        attempts += 1
+        if attempts > 20:
+            assert False, "server not started"
+        print("waiting for server to start...")
         time.sleep(0.1)


 @step(u"the server is {expecting_status}")
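The attempt counter and the two prints are the "logs to investigate why the CI server test job is not starting": instead of waiting forever, the step now fails after roughly 20 polls and reports progress. Note the raw-socket poll only proves the TCP port is open, not that the server can answer; a hedged sketch of a stricter variant that polls the /health route this suite uses elsewhere (function name and timings are illustrative):

# Illustrative variant, not from the commit: poll /health instead of the
# raw socket, so readiness means "the HTTP layer responds", not merely
# "the port accepts connections".
import time
import urllib.error
import urllib.request

def wait_until_healthy(base_url, attempts=20, delay=0.1):
    for _ in range(attempts):
        try:
            with urllib.request.urlopen(f"{base_url}/health") as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, ConnectionError):
            pass
        time.sleep(delay)
    return False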
@@ -301,6 +314,11 @@ def step_compute_embedding(context):
 @step(u'embeddings are generated')
 def step_compute_embeddings(context):
     assert len(context.embeddings) > 0
+    embeddings_computed = False
+    for emb in context.embeddings:
+        if emb != 0:
+            embeddings_computed = True
+    assert embeddings_computed, f"Embeddings: {context.embeddings}"


 @step(u'an OAI compatible embeddings computation request for')
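This is the heart of the commit: an all-zero vector would pass the bare length check, so the step now also requires at least one non-zero component. The same check as a one-liner, for comparison (equivalent, not the committed code):

# Equivalent one-liner for the committed loop: true as soon as any
# component of the embedding vector is non-zero.
assert any(emb != 0 for emb in context.embeddings), \
    f"Embeddings: {context.embeddings}"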
@@ -436,7 +454,8 @@ async def oai_chat_completions(user_prompt,
                          json=payload,
                          headers=headers) as response:
             if enable_streaming:
-                print("payload", payload)
+                # FIXME: does not work; the server is generating only one token
+                print("DEBUG payload", payload)
                 assert response.status == 200
                 assert response.headers['Access-Control-Allow-Origin'] == origin
                 assert response.headers['Content-Type'] == "text/event-stream"
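For reference, the `text/event-stream` body asserted above is normally consumed line by line; a simplified sketch of the decoding the new FIXME refers to (assumes aiohttp, the client library these `async with ... as response` blocks come from):

# Simplified sketch of SSE decoding with aiohttp: each event arrives as a
# "data: {...}" line, and the JSON payload carries one streamed delta.
import json
import aiohttp

async def read_stream_deltas(response: aiohttp.ClientResponse):
    async for raw in response.content:
        line = raw.decode('utf-8').strip()
        if line.startswith('data: '):
            yield json.loads(line[len('data: '):])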
@@ -453,7 +472,7 @@ async def oai_chat_completions(user_prompt,
                     if 'content' in delta:
                         completion_response['content'] += delta['content']
                         completion_response['timings']['predicted_n'] += 1
-                        print(f"XXXXXXXXXXXXXXXXXcompletion_response: {completion_response}")
+                        print(f"DEBUG completion_response: {completion_response}")
         else:
             if expect_api_error is None or not expect_api_error:
                 assert response.status == 200
@@ -500,7 +519,7 @@ async def oai_chat_completions(user_prompt,
                 'predicted_n': chat_completion.usage.completion_tokens
             }
         }
-    print("OAI response formatted to llama.cpp", completion_response)
+    print("OAI response formatted to llama.cpp:", completion_response)
     return completion_response

@@ -567,7 +586,7 @@ async def wait_for_health_status(context,
     # Sometimes health requests are triggered after completions are predicted
     if expected_http_status_code == 503:
         if len(context.completions) == 0:
-            print("\x1b[5;37;43mWARNING: forcing concurrents completions tasks,"
+            print("\x1b[33;42mWARNING: forcing concurrents completions tasks,"
                   " busy health check missed\x1b[0m")
             n_completions = await gather_concurrent_completions_tasks(context)
             if n_completions > 0:
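The only change here is the escape sequence. These are ANSI SGR codes: the old `5;37;43` means blink, white foreground, yellow background, while the new `33;42` means yellow foreground, green background; `0` resets. A quick way to see both (illustrative):

# SGR demo: 5=blink, 37=white fg, 43=yellow bg; 33=yellow fg, 42=green bg;
# \x1b[0m resets all attributes.
print("\x1b[5;37;43mold warning style\x1b[0m")
print("\x1b[33;42mnew warning style\x1b[0m")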
@@ -604,6 +623,8 @@ def start_server_background(context):
     ]
     if context.server_continuous_batching:
         server_args.append('--cont-batching')
+    if context.server_embeddings:
+        server_args.append('--embedding')
     if context.model_alias is not None:
         server_args.extend(['--alias', context.model_alias])
     if context.server_seed is not None:
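This is where the new `context.server_embeddings` default becomes configurable behavior: a scenario that declared `embeddings extraction` now launches the server with `--embedding`. A condensed sketch of the mapping (values are placeholders, not from the commit):

# Condensed sketch with placeholder values: each boolean on the behave
# context maps to a bare flag, valued settings map to flag/value pairs.
server_embeddings = True           # set by "And embeddings extraction"
server_continuous_batching = False
server_args = ['--host', '127.0.0.1', '--port', 8080]  # placeholders
if server_continuous_batching:
    server_args.append('--cont-batching')
if server_embeddings:
    server_args.append('--embedding')
print(server_args)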
@@ -620,3 +641,4 @@ def start_server_background(context):
     context.server_process = subprocess.Popen(
         [str(arg) for arg in [context.server_path, *server_args]],
         close_fds=True)
+    print(f"server pid={context.server_process.pid}")
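`server_args` mixes ints (port, seed, KV cache size) with strings, and `Popen` accepts only strings, hence the `str(arg)` comprehension; the added line logs the child PID for the CI investigation. A self-contained sketch of the same conversion (placeholder command, not the server binary):

# Self-contained sketch: echo stands in for the server binary; note how
# the int 8080 must be stringified before Popen will accept the list.
import subprocess

args = ['echo', '--port', 8080]  # placeholder command
proc = subprocess.Popen([str(a) for a in args], close_fds=True)
print(f"pid={proc.pid}")
proc.wait()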