py : type-check all Python scripts with Pyright
parent 87e25a1d1b
commit e29fd9634c
35 changed files with 264 additions and 136 deletions
examples/server/bench/bench.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import json
 import os
@@ -59,10 +61,11 @@ def main(args_in: list[str] | None = None) -> None:
         sys.exit(1)

     # start the benchmark
+    iterations = 0
+    data = {}
     try:
         start_benchmark(args)

-        iterations = 0
         with open("results.github.env", 'w') as github_env:
             # parse output
             with open('k6-results.json', 'r') as bench_results:
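Note: hoisting `iterations = 0` and `data = {}` above the `try` is the standard fix for Pyright's possibly-unbound diagnostics: a name first assigned inside a `try` body is not guaranteed to be bound in code that runs after an exception. A minimal sketch of the pattern (hypothetical names, not code from this commit):

import json

def load_results() -> int:
    data = {}  # without this line, Pyright reports `data` below as possibly unbound
    try:
        with open('k6-results.json') as bench_results:
            data = json.load(bench_results)
    except OSError:
        pass
    return len(data)  # fine: `data` is bound on every path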
@@ -129,7 +132,7 @@ def main(args_in: list[str] | None = None) -> None:
                     timestamps, metric_values = zip(*values)
                     metric_values = [float(value) for value in metric_values]
                     prometheus_metrics[metric] = metric_values
-                    timestamps_dt = [datetime.fromtimestamp(int(ts)) for ts in timestamps]
+                    timestamps_dt = [str(datetime.fromtimestamp(int(ts))) for ts in timestamps]
                     plt.figure(figsize=(16, 10), dpi=80)
                     plt.plot(timestamps_dt, metric_values, label=metric)
                     plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7)
@@ -156,7 +159,7 @@ def main(args_in: list[str] | None = None) -> None:
                     plt.close()

                     # Mermaid format in case images upload failed
-                    with (open(f"{metric}.mermaid", 'w') as mermaid_f):
+                    with open(f"{metric}.mermaid", 'w') as mermaid_f:
                         mermaid = (
                             f"""---
 config:
@@ -278,7 +281,7 @@ def start_server_background(args):
     }
     server_process = subprocess.Popen(
         args,
-        **pkwargs)
+        **pkwargs)  # pyright: ignore[reportArgumentType, reportCallIssue]

     def server_log(in_stream, out_stream):
         for line in iter(in_stream.readline, b''):
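Note: the `# pyright: ignore[...]` on the `Popen` call is a targeted suppression. `subprocess.Popen` is typed with overloads that depend on the concrete argument types, and unpacking a plain `dict` as `**pkwargs` erases that information, so Pyright cannot pick an overload and reports `reportArgumentType`/`reportCallIssue` even though the call is fine at runtime. A sketch of the same situation (hypothetical values):

import subprocess

# a homogeneous kwargs dict types as dict[str, int]; Pyright cannot match
# the Popen overloads from it, although the call works at runtime
pkwargs = {
    'stdout': subprocess.PIPE,
    'stderr': subprocess.PIPE,
}
proc = subprocess.Popen(
    ['echo', 'ok'],
    **pkwargs)  # pyright: ignore[reportArgumentType, reportCallIssue]
proc.wait()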
examples/server/tests/features/steps/steps.py
@@ -1,5 +1,4 @@
 import asyncio
-import collections
 import json
 import os
 import re
@@ -8,16 +7,20 @@ import subprocess
 import sys
 import threading
 import time
+from collections.abc import Sequence
 from contextlib import closing
 from re import RegexFlag
+from typing import cast

 import aiohttp
 import numpy as np
 import openai
-from behave import step
+from openai.types.chat import ChatCompletionChunk
+from behave import step  # pyright: ignore[reportAttributeAccessIssue]
 from behave.api.async_step import async_run_until_complete
 from prometheus_client import parser

+# pyright: reportRedeclaration=false

 @step("a server listening on {server_fqdn}:{server_port}")
 def step_server_config(context, server_fqdn, server_port):
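Note: `# pyright: reportRedeclaration=false` is a file-scope directive. behave step modules routinely rebind the same function name under successive `@step` decorators, which is fine for behave's registry but trips Pyright's redeclaration check, so the rule is disabled for this one file rather than project-wide. A sketch of the pattern (hypothetical steps, not from this commit):

# pyright: reportRedeclaration=false
from behave import step  # pyright: ignore[reportAttributeAccessIssue]

@step('a prompt {prompt}')
def step_impl(context, prompt):
    context.prompt = prompt

@step('a seed {seed:d}')
def step_impl(context, seed):  # rebinding `step_impl`: fine for behave,
    context.seed = seed        # flagged by Pyright without the directive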
@@ -777,8 +780,8 @@ def step_assert_metric_value(context, metric_name, metric_value):
 def step_available_models(context):
     # openai client always expects an api_key
     openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
-    openai.api_base = f'{context.base_url}/v1'
-    context.models = openai.Model.list().data
+    openai.base_url = f'{context.base_url}/v1'
+    context.models = openai.models.list().data


 @step('{n_model:d} models are supported')
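Note: this hunk is part of the openai-python 0.x to 1.x migration: the module-level `api_base` attribute became `base_url`, and the `openai.Model` resource class was replaced by the `openai.models` accessor. A sketch of the two styles side by side (assumes a llama.cpp server on localhost:8080):

import openai

openai.api_key = 'nope'  # the client always expects some api_key value

# openai-python 0.x:
#   openai.api_base = 'http://localhost:8080/v1'
#   models = openai.Model.list().data

# openai-python 1.x, module-level client:
openai.base_url = 'http://localhost:8080/v1'
models = openai.models.list().data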
@@ -810,6 +813,7 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
     print(f"starting {context.n_prompts} concurrent completion requests...")
     assert context.n_prompts > 0
     seeds = await completions_seed(context)
+    assert seeds is not None
     for prompt_no in range(context.n_prompts):
         shifted_args = [context.prompts.pop(), seeds[prompt_no], *args]
         context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
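Note: `assert seeds is not None` exists purely for the type checker: `completions_seed()` can return `None`, and without the narrowing assert Pyright flags the `seeds[prompt_no]` subscript. A minimal sketch of the technique (hypothetical helper):

from __future__ import annotations

def maybe_seeds(n: int) -> list[int] | None:
    return list(range(n)) if n > 0 else None

seeds = maybe_seeds(4)
assert seeds is not None  # narrows list[int] | None to list[int]
print(seeds[0])           # Pyright no longer reports a possibly-None subscript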
@@ -989,32 +993,35 @@ async def oai_chat_completions(user_prompt,
     else:
         try:
             openai.api_key = user_api_key
-            openai.api_base = f'{base_url}{base_path}'
-            chat_completion = openai.Completion.create(
+            openai.base_url = f'{base_url}{base_path}'
+            assert model is not None
+            chat_completion = openai.chat.completions.create(
                 messages=payload['messages'],
                 model=model,
                 max_tokens=n_predict,
                 stream=enable_streaming,
-                response_format=payload.get('response_format'),
+                response_format=payload.get('response_format') or openai.NOT_GIVEN,
                 seed=seed,
                 temperature=payload['temperature']
             )
-        except openai.error.AuthenticationError as e:
+        except openai.AuthenticationError as e:
             if expect_api_error is not None and expect_api_error:
                 return 401
             else:
                 assert False, f'error raised: {e}'

         if enable_streaming:
+            chat_completion = cast(openai.Stream[ChatCompletionChunk], chat_completion)
             for chunk in chat_completion:
                 assert len(chunk.choices) == 1
                 delta = chunk.choices[0].delta
-                if 'content' in delta:
-                    completion_response['content'] += delta['content']
+                if delta.content is not None:
+                    completion_response['content'] += delta.content
                     completion_response['timings']['predicted_n'] += 1
                 completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
         else:
             assert len(chat_completion.choices) == 1
+            assert chat_completion.usage is not None
             completion_response = {
                 'content': chat_completion.choices[0].message.content,
                 'timings': {
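Note: with `stream=enable_streaming` typed as a plain `bool`, the 1.x `create()` call is typed as returning the union `ChatCompletion | Stream[ChatCompletionChunk]`, so the streaming branch narrows it by hand with `cast()`; likewise `payload.get('response_format') or openai.NOT_GIVEN` substitutes the 1.x omitted-argument sentinel when the payload carries no response format. A sketch of the narrowing (types only, no network call, hypothetical function):

from __future__ import annotations

from typing import cast

import openai
from openai.types.chat import ChatCompletion, ChatCompletionChunk

def collect_content(completion: ChatCompletion | openai.Stream[ChatCompletionChunk],
                    streaming: bool) -> str:
    if streaming:
        # create(stream=some_bool) returns a union; cast() selects the
        # streaming member for the checker
        chunks = cast(openai.Stream[ChatCompletionChunk], completion)
        return ''.join(chunk.choices[0].delta.content or '' for chunk in chunks)
    completion = cast(ChatCompletion, completion)
    return completion.choices[0].message.content or ''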
@@ -1063,7 +1070,7 @@ async def request_oai_embeddings(input, seed,
                 response_json = await response.json()
                 assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
                 assert response_json['object'] == 'list'
-                if isinstance(input, collections.abc.Sequence):
+                if isinstance(input, Sequence):
                     embeddings = []
                     for an_oai_embeddings in response_json['data']:
                         embeddings.append(an_oai_embeddings['embedding'])
@@ -1072,19 +1079,14 @@ async def request_oai_embeddings(input, seed,
                 return embeddings
     else:
         openai.api_key = user_api_key
-        openai.api_base = f'{base_url}/v1'
-        oai_embeddings = openai.Embedding.create(
+        openai.base_url = f'{base_url}/v1'
+        assert model is not None
+        oai_embeddings = openai.embeddings.create(
             model=model,
             input=input,
         )

-        if isinstance(input, collections.abc.Sequence):
-            embeddings = []
-            for an_oai_embeddings in oai_embeddings.data:
-                embeddings.append(an_oai_embeddings.embedding)
-        else:
-            embeddings = [oai_embeddings.data.embedding]
-        return embeddings
+        return oai_embeddings.data


 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
@@ -1122,7 +1124,7 @@ def assert_all_predictions_equal(completion_responses):
             if i == j:
                 continue
             content_j = response_j['content']
-        assert content_i == content_j, "contents not equal"
+            assert content_i == content_j, "contents not equal"


 def assert_all_predictions_different(completion_responses):
@@ -1136,7 +1138,7 @@ def assert_all_predictions_different(completion_responses):
             if i == j:
                 continue
             content_j = response_j['content']
-        assert content_i != content_j, "contents not different"
+            assert content_i != content_j, "contents not different"


 def assert_all_token_probabilities_equal(completion_responses):
@@ -1153,7 +1155,7 @@ def assert_all_token_probabilities_equal(completion_responses):
                 if i == j:
                     continue
                 probs_j = response_j['completion_probabilities'][pos]['probs']
-            assert probs_i == probs_j, "contents not equal"
+                assert probs_i == probs_j, "contents not equal"


 async def gather_tasks_results(context):
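Note: the three hunks above look like no-ops because this page's rendering strips leading whitespace; the change in each appears to be indentation only. The asserts previously sat one level outside the inner `for j, response_j ...` loop, where Pyright reports `content_j`/`probs_j` as possibly unbound (and only the last pair would actually be checked); moving them into the loop fixes both. A sketch of the corrected shape, assuming this reading of the hunks:

responses = [{'content': 'a'}, {'content': 'a'}]
for i, response_i in enumerate(responses):
    content_i = response_i['content']
    for j, response_j in enumerate(responses):
        if i == j:
            continue
        content_j = response_j['content']
        # inside the loop: every pair is compared, and `content_j`
        # is always bound when the assert runs
        assert content_i == content_j, "contents not equal"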
@@ -1343,7 +1345,7 @@ def start_server_background(context):
     }
     context.server_process = subprocess.Popen(
         [str(arg) for arg in [context.server_path, *server_args]],
-        **pkwargs)
+        **pkwargs)  # pyright: ignore[reportArgumentType, reportCallIssue]

     def server_log(in_stream, out_stream):
         for line in iter(in_stream.readline, b''):
examples/server/tests/requirements.txt
@@ -1,6 +1,6 @@
 aiohttp~=3.9.3
 behave~=1.2.6
 huggingface_hub~=0.20.3
-numpy~=1.24.4
-openai~=0.25.0
+numpy~=1.26.4
+openai~=1.30.3
 prometheus-client~=0.20.0
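Note: the `openai~=1.30.3` pin matches the 1.x client API the steps now use, and numpy moves to the 1.26 line. A tiny guard like the following can make an environment/API mismatch fail fast (a sketch, not part of this commit):

import openai

major = int(openai.__version__.split('.')[0])
assert major >= 1, f"openai-python 1.x required, found {openai.__version__}"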