server: tests: passkey challenge / self-extend with context shift demo (#5832)
* server: tests: add models endpoint scenario
* server: /v1/models add some metadata
* server: tests: add debug field in context before scenario
* server: tests: download model from HF, add batch size
* server: tests: add passkey test
* server: tests: add group attention params
* server: do not truncate prompt tokens if self-extend through group attention is enabled
* server: logs: do not truncate log values
* server: tests - passkey - first good working value of nga
* server: tests: fix server timeout
* server: tests: fix passkey, add doc, fix regex content matching, fix timeout
* server: tests: fix regex content matching
* server: tests: schedule slow tests on master
* server: metrics: fix when no prompt processed
* server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1
* server: tests: increase timeout for completion
* server: tests: keep only the PHI-2 test
* server: tests: passkey add a negative test
parent 4a6e2d6142
commit 9731134296

14 changed files with 363 additions and 112 deletions
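For orientation, below is a minimal sketch of how the step definitions added in this diff might be combined in a behave feature file for a passkey / self-extend scenario. The step phrases are taken verbatim from the step definitions in the diff; the model file, repo name and all numeric values are hypothetical placeholders, the doc-string bodies expected by the prompt-template steps are elided, and the server start/health steps are omitted. The PR's actual feature files may differ.

    # Hypothetical sketch only, not the feature file shipped with this PR.
    Scenario: Passkey challenge
      Given a model file dummy-model.Q4_K_M.gguf from HF repo dummy-org/dummy-models
      And   42 as server seed
      And   2048 as batch size
      And   8192 KV cache size
      And   4 group attention factor to extend context size through self-extend
      And   512 group attention width to extend context size through self-extend
      And   64 as number of junk
      # 'a prefix prompt', 'a passkey prompt template' and 'a junk suffix prompt'
      # each take a behave doc-string with the prompt text (omitted here)
      When  a "42" passkey challenge prompt with the passkey inserted every 32 junk
      And   an OAI compatible chat completions request with no api error
      Then  1 tokens are predicted matching 42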
@@ -13,6 +13,7 @@ import aiohttp
import openai
from behave import step
from behave.api.async_step import async_run_until_complete
from huggingface_hub import hf_hub_download
from prometheus_client import parser
@@ -26,17 +27,23 @@ def step_server_config(context, server_fqdn, server_port):
    context.base_url = f'http://{context.server_fqdn}:{context.server_port}'

    context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
    context.model_alias = None
    context.n_batch = None
    context.n_ctx = None
    context.n_ga = None
    context.n_ga_w = None
    context.n_gpu_layer = None
    context.n_predict = None
    context.n_server_predict = None
    context.n_slots = None
    context.prompt_prefix = None
    context.prompt_suffix = None
    context.server_api_key = None
    context.server_continuous_batching = False
    context.server_embeddings = False
    context.server_metrics = False
    context.server_process = None
    context.seed = None
    context.server_seed = None
    context.user_api_key = None
@@ -45,9 +52,11 @@ def step_server_config(context, server_fqdn, server_port):
    context.prompts = []


@step(u'a model file {model_file}')
def step_model_file(context, model_file):
    context.model_file = model_file
@step(u'a model file {hf_file} from HF repo {hf_repo}')
def step_download_hf_model(context, hf_file, hf_repo):
    context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
    if context.debug:
        print(f"model file: {context.model_file}\n")


@step(u'a model alias {model_alias}')
@@ -55,24 +64,34 @@ def step_model_alias(context, model_alias):
    context.model_alias = model_alias


@step(u'{seed} as server seed')
@step(u'{seed:d} as server seed')
def step_seed(context, seed):
    context.server_seed = int(seed)
    context.server_seed = seed


@step(u'{n_ctx} KV cache size')
@step(u'{ngl:d} GPU offloaded layers')
def step_n_gpu_layer(context, ngl):
    if 'N_GPU_LAYERS' in os.environ:
        new_ngl = int(os.environ['N_GPU_LAYERS'])
        if context.debug:
            print(f"-ngl upgraded from {ngl} to {new_ngl}")
        ngl = new_ngl
    context.n_gpu_layer = ngl


@step(u'{n_ctx:d} KV cache size')
def step_n_ctx(context, n_ctx):
    context.n_ctx = int(n_ctx)
    context.n_ctx = n_ctx


@step(u'{n_slots} slots')
@step(u'{n_slots:d} slots')
def step_n_slots(context, n_slots):
    context.n_slots = int(n_slots)
    context.n_slots = n_slots


@step(u'{n_predict} server max tokens to predict')
@step(u'{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict):
    context.n_server_predict = int(n_predict)
    context.n_server_predict = n_predict


@step(u'continuous batching')
@@ -116,11 +135,13 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):

        case 'ready' | 'idle':
            await wait_for_health_status(context, context.base_url, 200, 'ok',
                                         timeout=10,
                                         params={'fail_on_no_slot': 0, 'include_slots': 0},
                                         slots_idle=context.n_slots,
                                         slots_processing=0,
                                         expected_slots=[{'id': slot_id, 'state': 0}
                                                         for slot_id in range(context.n_slots)])
                                                         for slot_id in
                                                         range(context.n_slots if context.n_slots else 1)])
        case 'busy':
            await wait_for_health_status(context, context.base_url, 503,
                                         'no slot available',
@@ -128,7 +149,8 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
                                         slots_idle=0,
                                         slots_processing=context.n_slots,
                                         expected_slots=[{'id': slot_id, 'state': 1}
                                                         for slot_id in range(context.n_slots)])
                                                         for slot_id in
                                                         range(context.n_slots if context.n_slots else 1)])
        case _:
            assert False, "unknown status"
@@ -157,24 +179,24 @@ async def step_request_completion(context, api_error):
                                          context.base_url,
                                          debug=context.debug,
                                          n_predict=context.n_predict,
                                          server_seed=context.server_seed,
                                          seed=await completions_seed(context),
                                          expect_api_error=expect_api_error,
                                          user_api_key=context.user_api_key)
    context.tasks_result.append(completion)
    if context.debug:
        print(f"Completion response: {completion}")
        print(f"Completion response: {completion}\n")
    if expect_api_error:
        assert completion == 401, f"completion must be an 401 status code: {completion}"


@step(u'{predicted_n} tokens are predicted matching {re_content}')
@step(u'{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)
    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)


@step(u'{predicted_n} tokens are predicted')
@step(u'{predicted_n:d} tokens are predicted')
def step_n_tokens_predicted(context, predicted_n):
    assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))
    assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)


@step(u'a user prompt {user_prompt}')
@@ -192,9 +214,9 @@ def step_model(context, model):
    context.model = model


@step(u'{max_tokens} max tokens to predict')
@step(u'{max_tokens:d} max tokens to predict')
def step_max_tokens(context, max_tokens):
    context.n_predict = int(max_tokens)
    context.n_predict = max_tokens


@step(u'streaming is {enable_streaming}')
@@ -222,11 +244,70 @@ def step_server_api_key(context, server_api_key):
    context.server_api_key = server_api_key


@step(u'{n_junk:d} as number of junk')
def step_n_junk(context, n_junk):
    context.n_junk = n_junk


@step(u'{n_batch:d} as batch size')
def step_n_batch(context, n_batch):
    context.n_batch = n_batch


@step(u'{seed:d} as seed')
def step_seed(context, seed):
    context.seed = seed


@step(u'a prefix prompt')
def step_prompt_prefix(context):
    context.prompt_prefix = context.text


@step(u'a junk suffix prompt')
def step_prompt_junk_suffix(context):
    context.prompt_junk_suffix = context.text


@step(u'a suffix prompt')
def step_prompt_suffix(context):
    context.prompt_suffix = context.text


@step(u'{n_ga:d} group attention factor'
      u' to extend context size through self-extend')
def step_impl(context, n_ga):
    context.n_ga = n_ga


@step(u'{n_ga_w:d} group attention width to extend context size through self-extend')
def step_impl(context, n_ga_w):
    context.n_ga_w = n_ga_w


@step(u'a passkey prompt template')
def step_prompt_passkey(context):
    context.prompt_passkey = context.text


@step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
def step_prompt_passkey(context, passkey, i_pos):
    prompt = ""
    for i in range(context.n_junk):
        if i % context.n_junk == i_pos:
            prompt += context.prompt_passkey # the passkey is already substituted
        prompt += context.prompt_junk_suffix
    if context.debug:
        passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
    context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)


@step(u'an OAI compatible chat completions request with {api_error} api error')
@async_run_until_complete
async def step_oai_chat_completions(context, api_error):
    if context.debug:
        print(f"Submitting OAI compatible completions request...")
        print(f"Submitting OAI compatible completions request...\n")
    expect_api_error = api_error == 'raised'
    completion = await oai_chat_completions(context.prompts.pop(),
                                            context.system_prompt,
@@ -241,8 +322,7 @@ async def step_oai_chat_completions(context, api_error):
                                            enable_streaming=context.enable_streaming
                                            if hasattr(context, 'enable_streaming') else None,
                                            server_seed=context.server_seed
                                            if hasattr(context, 'server_seed') else None,
                                            seed=await completions_seed(context),
                                            user_api_key=context.user_api_key
                                            if hasattr(context, 'user_api_key') else None,
@@ -276,8 +356,10 @@ async def step_concurrent_completion_requests(context):
                              # prompt is inserted automatically
                              context.base_url,
                              debug=context.debug,
                              prompt_prefix=context.prompt_prefix,
                              prompt_suffix=context.prompt_suffix,
                              n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
                              server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
                              seed=await completions_seed(context),
                              user_api_key=context.user_api_key if hasattr(context,
                                                                           'user_api_key') else None)
@@ -297,8 +379,7 @@ async def step_oai_chat_completions(context):
                              if hasattr(context, 'n_predict') else None,
                              enable_streaming=context.enable_streaming
                              if hasattr(context, 'enable_streaming') else None,
                              server_seed=context.server_seed
                              if hasattr(context, 'server_seed') else None,
                              seed=await completions_seed(context),
                              user_api_key=context.user_api_key
                              if hasattr(context, 'user_api_key') else None)
@@ -318,7 +399,9 @@ async def step_oai_chat_completions(context):
                              if hasattr(context, 'n_predict') else None,
                              enable_streaming=context.enable_streaming
                              if hasattr(context, 'enable_streaming') else None,
                              server_seed=context.server_seed
                              seed=context.seed
                              if hasattr(context, 'seed') else
                              context.server_seed
                              if hasattr(context, 'server_seed') else None,
                              user_api_key=context.user_api_key
                              if hasattr(context, 'user_api_key') else None)
@@ -330,11 +413,10 @@ async def step_all_prompts_are_predicted(context):
    await all_prompts_are_predicted(context)


@step(u'all prompts are predicted with {n_predict} tokens')
@step(u'all prompts are predicted with {n_expected_predicted:d} tokens')
@async_run_until_complete
async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict):
    expected_predicted_n = int(n_predict)
    await all_prompts_are_predicted(context, expected_predicted_n)
async def step_all_prompts_are_predicted_with_n_tokens(context, n_expected_predicted):
    await all_prompts_are_predicted(context, n_expected_predicted)


async def all_prompts_are_predicted(context, expected_predicted_n=None):
@@ -464,6 +546,8 @@ async def step_prometheus_metrics_exported(context):
            assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4"
            metrics_raw = await metrics_response.text()
            metric_exported = False
            if context.debug:
                print(f"/metrics answer:\n{metrics_raw}\n")
            for metric in parser.text_string_to_metric_families(metrics_raw):
                match metric.name:
                    case "llamacpp:kv_cache_usage_ratio":
@@ -472,6 +556,37 @@ async def step_prometheus_metrics_exported(context):
    assert metric_exported, "No metrics exported"


@step(u'available models')
def step_available_models(context):
    # openai client always expects an api_key
    openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
    openai.api_base = f'{context.base_url}/v1'
    context.models = openai.Model.list().data


@step(u'{n_model:d} models are supported')
def step_supported_models(context, n_model):
    if context.debug:
        print("server models available:", context.models)
    assert len(context.models) == n_model


@step(u'model {i_model:d} is {param} {preposition} {param_value}')
def step_supported_models(context, i_model, param, preposition, param_value):
    assert i_model < len(context.models)
    model = context.models[i_model]

    param_value = param_value.split(' ', 1)[0]
    match param:
        case 'identified':
            value = model.id
        case 'trained':
            value = str(model.meta.n_ctx_train)
        case _:
            assert False, "param {param} not supported"
    assert param_value == value, f"model param {param} {value} != {param_value}"


async def concurrent_requests(context, f_completion, *args, **kwargs):
    n_prompts = len(context.prompts)
    if context.debug:
@@ -486,8 +601,10 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
async def request_completion(prompt,
                             base_url,
                             debug=False,
                             prompt_prefix=None,
                             prompt_suffix=None,
                             n_predict=None,
                             server_seed=None,
                             seed=None,
                             expect_api_error=None,
                             user_api_key=None):
    if debug:
@@ -504,11 +621,14 @@ async def request_completion(prompt,
    async with aiohttp.ClientSession() as session:
        async with session.post(f'{base_url}/completion',
                                json={
                                    "input_prefix": prompt_prefix,
                                    "prompt": prompt,
                                    "n_predict": int(n_predict) if n_predict is not None else -1,
                                    "seed": server_seed if server_seed is not None else 42
                                    "input_suffix": prompt_suffix,
                                    "n_predict": n_predict if n_predict is not None else -1,
                                    "seed": seed if seed is not None else 42
                                },
                                headers=headers) as response:
                                headers=headers,
                                timeout=3600) as response:
            if expect_api_error is None or not expect_api_error:
                assert response.status == 200
                assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -526,14 +646,14 @@ async def oai_chat_completions(user_prompt,
                               model=None,
                               n_predict=None,
                               enable_streaming=None,
                               server_seed=None,
                               seed=None,
                               user_api_key=None,
                               expect_api_error=None):
    if debug:
        print(f"Sending OAI Chat completions request: {user_prompt}")
    # openai client always expects an api key
    user_api_key = user_api_key if user_api_key is not None else 'nope'
    seed = server_seed if server_seed is not None else 42
    seed = seed if seed is not None else 42
    enable_streaming = enable_streaming if enable_streaming is not None else False
    payload = {
        "messages": [
@@ -692,20 +812,32 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
    content = completion_response['content']
    n_predicted = completion_response['timings']['predicted_n']
    assert len(content) > 0, "no token predicted"
    if expected_predicted_n is not None:
    if re_content is not None:
        p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
        matches = p.finditer(content)
        last_match = 0
        highlighted = ''
        for match in matches:
            start, end = match.span()
            highlighted += content[last_match: start]
            highlighted += '\x1b[33m'
            highlighted += content[start: end]
            highlighted += '\x1b[0m'
            last_match = end
        highlighted += content[last_match:]
        if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
            print(f"Checking completion response: {highlighted}\n")
        assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
    if expected_predicted_n and expected_predicted_n > 0:
        assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                     f' {n_predicted} <> {expected_predicted_n}')
    if re_content is not None:
        re_content = '^.*' + re_content.replace('<or>', '|') + '.*$'
        assert re.match(re_content, content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL), (
            f'invalid tokens predicted:'
            f' ```\n{content}\n``` do not match /{re_content}/')


async def gather_tasks_results(context):
    n_tasks = len(context.concurrent_tasks)
    if context.debug:
        print(f"Waiting for all {n_tasks} tasks results...")
        print(f"Waiting for all {n_tasks} tasks results...\n")
    for task_no in range(n_tasks):
        context.tasks_result.append(await context.concurrent_tasks.pop())
    n_completions = len(context.tasks_result)
@@ -716,15 +848,13 @@ async def wait_for_health_status(context,
                                 base_url,
                                 expected_http_status_code,
                                 expected_health_status,
                                 timeout=3,
                                 params=None,
                                 slots_idle=None,
                                 slots_processing=None,
                                 expected_slots=None):
    if context.debug:
        print(f"Starting checking for health for expected_health_status={expected_health_status}")
    timeout = 3 # seconds
    if expected_health_status == 'ok':
        timeout = 10 # CI slow inference
        print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
    interval = 0.5
    counter = 0
    async with aiohttp.ClientSession() as session:
@@ -734,7 +864,7 @@ async def wait_for_health_status(context,
                health = await health_response.json()
                if context.debug:
                    print(f"HEALTH - response for expected health status='{expected_health_status}' on "
                          f"'{base_url}/health'?{params} is {health}")
                          f"'{base_url}/health'?{params} is {health}\n")
                if (status_code == expected_http_status_code
                        and health['status'] == expected_health_status
                        and (slots_idle is None or health['slots_idle'] == slots_idle)
@@ -757,7 +887,7 @@ async def wait_for_health_status(context,
            if expected_http_status_code == 503:
                if len(context.tasks_result) == 0:
                    print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
                          " busy health check missed, probably too fast inference\x1b[0m")
                          " busy health check missed, probably too fast inference\x1b[0m\n")
                    n_completions = await gather_tasks_results(context)
                    if n_completions > 0:
                        return
@@ -791,6 +921,11 @@ def assert_slots_status(slots, expected_slots):
                                                f" = {expected[key]} != {slot[key]}")


async def completions_seed(context):
    return context.seed if hasattr(context, 'seed') and context.seed is not None \
        else context.server_seed if hasattr(context, 'server_seed') else None


def start_server_background(context):
    context.server_path = '../../../build/bin/server'
    if 'LLAMA_SERVER_BIN_PATH' in os.environ:
@@ -800,27 +935,35 @@ def start_server_background(context):
        '--port', context.server_port,
        '--model', context.model_file
    ]
    if context.n_batch:
        server_args.extend(['--batch-size', context.n_batch])
    if context.n_gpu_layer:
        server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
    if context.server_continuous_batching:
        server_args.append('--cont-batching')
    if context.server_embeddings:
        server_args.append('--embedding')
    if context.server_metrics:
        server_args.append('--metrics')
    if context.model_alias is not None:
    if context.model_alias:
        server_args.extend(['--alias', context.model_alias])
    if context.n_ctx is not None:
    if context.n_ctx:
        server_args.extend(['--ctx-size', context.n_ctx])
    if context.n_slots is not None:
    if context.n_slots:
        server_args.extend(['--parallel', context.n_slots])
    if context.n_server_predict is not None:
    if context.n_server_predict:
        server_args.extend(['--n-predict', context.n_server_predict])
    if context.server_api_key is not None:
    if context.server_api_key:
        server_args.extend(['--api-key', context.server_api_key])
    if context.n_ga:
        server_args.extend(['--grp-attn-n', context.n_ga])
    if context.n_ga_w:
        server_args.extend(['--grp-attn-w', context.n_ga_w])
    if context.debug:
        server_args.append('--verbose')
    if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
        server_args.extend(['--log-format', "text"])
    print(f"starting server with: {context.server_path}", *server_args)
    print(f"starting server with: {context.server_path} {server_args}\n")
    context.server_process = subprocess.Popen(
        [str(arg) for arg in [context.server_path, *server_args]],
        close_fds=True)