server: fix disconnection logic in test (before post response headers)

This commit is contained in:
ochafik 2024-10-05 00:44:13 +01:00
parent 6f693f14b0
commit 52c5a6244f

View file

@ -1012,6 +1012,9 @@ async def request_completion(prompt,
headers['Authorization'] = f'Bearer {user_api_key}'
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
if disconnect_after_millis is not None:
await asyncio.sleep(disconnect_after_millis / 1000.0)
return 0
async with session.post(f'{base_url}/completion',
json={
"input_prefix": prompt_prefix,
@ -1025,10 +1028,6 @@ async def request_completion(prompt,
"n_probs": 2,
},
headers=headers) as response:
if disconnect_after_millis is not None:
await asyncio.sleep(disconnect_after_millis / 1000.0)
return 0
if expect_api_error is None or not expect_api_error:
assert response.status == 200
assert response.headers['Access-Control-Allow-Origin'] == origin
@ -1088,13 +1087,12 @@ async def oai_chat_completions(user_prompt,
origin = 'llama.cpp'
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
if disconnect_after_millis is not None:
await asyncio.sleep(disconnect_after_millis / 1000.0)
return 0
async with session.post(f'{base_url}{base_path}',
json=payload,
headers=headers) as response:
if disconnect_after_millis is not None:
await asyncio.sleep(disconnect_after_millis / 1000.0)
return 0
if enable_streaming:
assert response.status == 200
assert response.headers['Access-Control-Allow-Origin'] == origin