py : type-check all Python scripts with Pyright
commit e29fd9634c (parent 87e25a1d1b)
35 changed files with 264 additions and 136 deletions
@@ -353,7 +353,7 @@ class Metadata:
     version: Optional[str] = None
     url: Optional[str] = None
     description: Optional[str] = None
-    licence: Optional[str] = None
+    license: Optional[str] = None
     source_url: Optional[str] = None
     source_hf_repo: Optional[str] = None

@@ -492,12 +492,13 @@ class LazyTensor:

 LazyModel: TypeAlias = 'dict[str, LazyTensor]'

+ModelFormat: TypeAlias = Literal['ggml', 'torch', 'safetensors', 'none']
+
 @dataclass
 class ModelPlus:
     model: LazyModel
     paths: list[Path]  # Where this was read from.
-    format: Literal['ggml', 'torch', 'safetensors', 'none']
+    format: ModelFormat
     vocab: BaseVocab | None  # For GGML models (which have vocab built in), the vocab.

@@ -536,7 +537,7 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel:


 def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
-    formats = set(mp.format for mp in models_plus)
+    formats: set[ModelFormat] = set(mp.format for mp in models_plus)
     assert len(formats) == 1, "different formats?"
     format = formats.pop()
     paths = [path for mp in models_plus for path in mp.paths]

@@ -555,7 +556,7 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
     else:
         model = merge_sharded([mp.model for mp in models_plus])

-    return ModelPlus(model, paths, format, vocab)  # pytype: disable=wrong-arg-types
+    return ModelPlus(model, paths, format, vocab)


 def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor:

@@ -805,7 +806,7 @@ class OutputFile:
     def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
         self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)

-    def add_meta_model(self, params: Params, metadata: Metadata) -> None:
+    def add_meta_model(self, params: Params, metadata: Metadata | None) -> None:
         # Metadata About The Model And Its Provenence
         name = "LLaMA"
         if metadata is not None and metadata.name is not None:

@@ -827,8 +828,8 @@ class OutputFile:
             self.gguf.add_url(metadata.url)
         if metadata.description is not None:
             self.gguf.add_description(metadata.description)
-        if metadata.licence is not None:
-            self.gguf.add_licence(metadata.licence)
+        if metadata.license is not None:
+            self.gguf.add_licence(metadata.license)
         if metadata.source_url is not None:
             self.gguf.add_source_url(metadata.source_url)
         if metadata.source_hf_repo is not None:

@@ -943,7 +944,7 @@ class OutputFile:
     @staticmethod
     def write_vocab_only(
         fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
-        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata = None,
+        endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)

@@ -977,7 +978,7 @@ class OutputFile:
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
         concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
         pad_vocab: bool = False,
-        metadata: Metadata = None,
+        metadata: Metadata | None = None,
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)

@@ -1396,6 +1397,8 @@ def main(args_in: list[str] | None = None) -> None:
     if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab:
         vocab = model_plus.vocab

+    assert params is not None
+
     logger.info(f"Vocab info: {vocab}")
     logger.info(f"Special vocab info: {special_vocab}")
     model = model_plus.model

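The recurring pattern in the hunks above is making implicit optionals explicit: a parameter declared `metadata: Metadata = None` is flagged by Pyright, so the annotation becomes `Metadata | None` and the body narrows with a `None` check (or an `assert`, as with `params`). A minimal sketch of the idea, using a hypothetical stand-in `Metadata` dataclass rather than the converter's real one:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class Metadata:
    # hypothetical stand-in for the converter's Metadata class
    name: Optional[str] = None
    license: Optional[str] = None


def add_meta_model(metadata: Metadata | None) -> str:
    # the parameter is explicitly optional, so the checker requires a None
    # check before any attribute access
    name = "LLaMA"
    if metadata is not None and metadata.name is not None:
        name = metadata.name
    return name


def write_model(metadata: Metadata | None = None) -> None:
    # `metadata: Metadata = None` would be an invalid default under Pyright;
    # `Metadata | None = None` states the same intent in a checkable way
    print(add_meta_model(metadata))


write_model()                       # -> LLaMA
write_model(Metadata(name="tiny"))  # -> tiny
```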
@@ -74,7 +74,7 @@ class Tensor:
         if len(self.ne) == 0:
             self.nbytes = 0
         else:
-            self.nbytes = int(np.product(self.ne)) * 4
+            self.nbytes = int(np.prod(self.ne)) * 4
     else:
         raise ValueError(f"Unhandled data type '{self.dtype}'")

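`np.product` is a deprecated alias of `np.prod` (removed in NumPy 2.0), so the byte-size computation switches to the supported spelling. A small sketch of the same calculation with a hypothetical shape:

```python
import numpy as np

# shape of a hypothetical float32 tensor
ne = (32, 64)

# np.prod multiplies the dimensions; 4 bytes per float32 element
nbytes = int(np.prod(ne)) * 4 if len(ne) > 0 else 0
print(nbytes)  # 8192
```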
@@ -3,7 +3,7 @@
 #! pip install pydantic
 #! python json_schema_pydantic_example.py

-from pydantic import BaseModel, Extra, TypeAdapter
+from pydantic import BaseModel, Field, TypeAdapter
 from annotated_types import MinLen
 from typing import Annotated, List, Optional
 import json, requests

@@ -17,6 +17,9 @@ if True:

     The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below)
     '''
+    response_format = None
+    type_adapter = None
+
     if response_model:
         type_adapter = TypeAdapter(response_model)
         schema = type_adapter.json_schema()

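Assigning `response_format = None` and `type_adapter = None` before the conditional gives both names a binding on every path, which is what Pyright's possibly-unbound diagnostics look for. A stripped-down sketch of the pattern, assuming pydantic v2 is installed (`TypeAdapter` is a v2 API); the request-building logic here is illustrative, not the example script's full behaviour:

```python
from __future__ import annotations

from typing import Any, Optional

from pydantic import BaseModel, TypeAdapter


class Answer(BaseModel):
    value: int


def build_request(response_model: Optional[type[BaseModel]] = None) -> dict[str, Any]:
    # bind both names up front so they exist even when response_model is None
    response_format = None
    type_adapter = None

    if response_model:
        type_adapter = TypeAdapter(response_model)
        response_format = {"type": "json_object", "schema": type_adapter.json_schema()}

    # both names are safe to use on every path
    return {"response_format": response_format}


print(build_request(Answer))
print(build_request())
```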
@@ -1,4 +1,6 @@
 #!/usr/bin/env python3
+from __future__ import annotations
+
 import argparse
 import itertools
 import json

@@ -188,7 +190,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
         raise RuntimeError("At least one of min_value or max_value must be set")

 class BuiltinRule:
-    def __init__(self, content: str, deps: list = None):
+    def __init__(self, content: str, deps: list | None = None):
         self.content = content
         self.deps = deps or []

@@ -248,7 +250,7 @@ class SchemaConverter:

     def _format_literal(self, literal):
         escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
-            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
+            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal
         )
         return f'"{escaped}"'

@@ -403,11 +405,11 @@ class SchemaConverter:
         i = 0
         length = len(pattern)

-        def to_rule(s: Tuple[str, bool]) -> str:
+        def to_rule(s: tuple[str, bool]) -> str:
             (txt, is_literal) = s
             return "\"" + txt + "\"" if is_literal else txt

-        def transform() -> Tuple[str, bool]:
+        def transform() -> tuple[str, bool]:
             '''
                 Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
             '''

@@ -420,7 +422,7 @@ class SchemaConverter:
             # We only need a flat structure here to apply repetition operators to the last item, and
             # to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
             # (GBNF's syntax is luckily very close to regular expressions!)
-            seq: list[Tuple[str, bool]] = []
+            seq: list[tuple[str, bool]] = []

             def get_dot():
                 if self._dotall:

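Several of these hunks swap `typing.Tuple` for the built-in `tuple` in annotations; adding `from __future__ import annotations` at the top of the file keeps that working on older interpreters, because annotations are then stored as strings instead of being evaluated at import time. A minimal sketch, not tied to the converter's real rules:

```python
from __future__ import annotations


def to_rule(s: tuple[str, bool]) -> str:
    # with the future import this annotation is never evaluated at runtime,
    # so builtin generics like tuple[str, bool] are fine even on Python 3.8
    (txt, is_literal) = s
    return '"' + txt + '"' if is_literal else txt


print(to_rule(("hello", True)))       # "hello"
print(to_rule(("some-rule", False)))  # some-rule
```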
@@ -185,6 +185,8 @@ else:
     fout.add_description("two-tower CLIP model")

 if has_text_encoder:
+    assert t_hparams is not None
+    assert tokens is not None
     # text_model hparams
     fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
     fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])

@@ -259,8 +261,8 @@ if has_vision_encoder:


     if processor is not None:
-        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
-        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
+        image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean  # pyright: ignore[reportAttributeAccessIssue]
+        image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std  # pyright: ignore[reportAttributeAccessIssue]
     else:
         image_mean = args.image_mean if args.image_mean is not None else default_image_mean
         image_std = args.image_std if args.image_std is not None else default_image_std

@@ -272,7 +274,7 @@ fout.add_bool("clip.use_gelu", use_gelu)


 if has_llava_projector:
-    model.vision_model.encoder.layers.pop(-1)
+    model.vision_model.encoder.layers.pop(-1)  # pyright: ignore[reportAttributeAccessIssue]
     projector = torch.load(args.llava_projector)
     for name, data in projector.items():
         name = get_tensor_name(name)

@@ -286,7 +288,7 @@ if has_llava_projector:

     print("Projector tensors added\n")

-state_dict = model.state_dict()
+state_dict = model.state_dict()  # pyright: ignore[reportAttributeAccessIssue]
 for name, data in state_dict.items():
     if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
         # we don't need this

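Where the checker cannot see an attribute that exists at runtime (HuggingFace processors and models expose many members dynamically), the fix here is a scoped `# pyright: ignore[...]` on the offending line rather than a blanket suppression. A small, self-contained illustration with a deliberately dynamic object (the attribute and values are hypothetical, not the conversion script's):

```python
class DynamicConfig:
    def __init__(self) -> None:
        # attributes are attached dynamically, so static analysis cannot see them
        for key, value in {"image_mean": [0.5, 0.5, 0.5]}.items():
            setattr(self, key, value)


config = DynamicConfig()

# Pyright reports "image_mean" as an unknown attribute; the comment silences
# only this diagnostic on only this line, instead of ignoring the whole file.
image_mean = config.image_mean  # pyright: ignore[reportAttributeAccessIssue]
print(image_mean)
```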
@@ -2,7 +2,9 @@ import argparse
 import glob
 import os
 import torch
-from safetensors.torch import load as safe_load, save as safe_save, safe_open, save_file
+from safetensors import safe_open
+from safetensors.torch import save_file
+from typing import Any, ContextManager, cast

 # Function to determine if file is a SafeTensor file
 def is_safetensor_file(file_path):

@@ -13,7 +15,7 @@ def is_safetensor_file(file_path):
 def load_model(file_path):
     if is_safetensor_file(file_path):
         tensors = {}
-        with safe_open(file_path, framework="pt", device="cpu") as f:
+        with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
             for key in f.keys():
                 tensors[key] = f.get_tensor(key).clone()
                 # output shape

@@ -134,7 +136,7 @@ if len(mm_tensors) == 0:
     if last_checkpoint is not None:
         for k, v in last_checkpoint.items():
             print(k)
-    print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.")
+    print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
     print("No tensors found. Is this a LLaVA model?")
     exit()

@@ -143,8 +145,10 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
 # projector = {name: checkpoint.[name].float() for name in mm_tensors}
 projector = {}
 for name in mm_tensors:
+    assert last_checkpoint is not None
     projector[name] = last_checkpoint[name].float()
 for name in first_mm_tensors:
+    assert first_checkpoint is not None
     projector[name] = first_checkpoint[name].float()

 if len(projector) > 0:

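`safetensors.safe_open` works as a context manager at runtime, but the checker apparently does not see it that way, so the change wraps the call in `typing.cast(ContextManager[Any], ...)` to tell it what the `with` statement needs. A dependency-free sketch of the same move, using a stand-in opener instead of safetensors:

```python
from __future__ import annotations

from contextlib import contextmanager
from typing import Any, ContextManager, Iterator, cast


@contextmanager
def _open_stub(path: str) -> Iterator[dict[str, Any]]:
    # stand-in for safetensors.safe_open(path, framework="pt", device="cpu")
    yield {"weight": [1.0, 2.0, 3.0]}


def load_model(path: str) -> dict[str, Any]:
    tensors: dict[str, Any] = {}
    # cast() has no runtime cost; it only tells the type checker that the
    # returned object can be used in a `with` statement
    with cast(ContextManager[Any], _open_stub(path)) as f:
        for key in f:
            tensors[key] = f[key]
    return tensors


print(load_model("model.safetensors"))
```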
@@ -1,3 +1,4 @@
 -r ../../requirements/requirements-convert_legacy_llama.txt
+--extra-index-url https://download.pytorch.org/whl/cpu
 pillow~=10.2.0
 torch~=2.2.1

@@ -6,10 +6,10 @@ import re
 from copy import copy
 from enum import Enum
 from inspect import getdoc, isclass
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin

 from docstring_parser import parse
-from pydantic import BaseModel, Field, create_model
+from pydantic import BaseModel, create_model

 if TYPE_CHECKING:
     from types import GenericAlias

@@ -17,6 +17,9 @@ else:
     # python 3.8 compat
     from typing import _GenericAlias as GenericAlias

+# TODO: fix this
+# pyright: reportAttributeAccessIssue=information
+

 class PydanticDataType(Enum):
     """

@@ -234,8 +237,9 @@ def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None

     # Define the integer part rule
     integer_part_rule = (
-        "integer-part" + (f"-max{max_digit}" if max_digit is not None else "") + (
-            f"-min{min_digit}" if min_digit is not None else "")
+        "integer-part"
+        + (f"-max{max_digit}" if max_digit is not None else "")
+        + (f"-min{min_digit}" if min_digit is not None else "")
     )

     # Define the fractional part rule based on precision constraints

@@ -458,7 +462,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
     if not issubclass(model, BaseModel):
         # For non-Pydantic classes, generate model_fields from __annotations__ or __init__
         if hasattr(model, "__annotations__") and model.__annotations__:
-            model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()}
+            model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()}  # pyright: ignore[reportGeneralTypeIssues]
         else:
             init_signature = inspect.signature(model.__init__)
             parameters = init_signature.parameters

@@ -680,7 +684,7 @@ def generate_markdown_documentation(
         str: Generated text documentation.
     """
     documentation = ""
-    pyd_models = [(model, True) for model in pydantic_models]
+    pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models]
     for model, add_prefix in pyd_models:
         if add_prefix:
             documentation += f"{model_prefix}: {model.__name__}\n"

@@ -700,7 +704,7 @@ def generate_markdown_documentation(
             # Indenting the fields section
             documentation += f" {fields_prefix}:\n"
         else:
-            documentation += f" Fields:\n"
+            documentation += f" Fields:\n"  # noqa: F541
         if isclass(model) and issubclass(model, BaseModel):
             for name, field_type in model.__annotations__.items():
                 # if name == "markdown_code_block":

@@ -778,7 +782,7 @@ def generate_field_markdown(
         return field_text

     if field_description != "":
-        field_text += f" Description: " + field_description + "\n"
+        field_text += f" Description: {field_description}\n"

     # Check for and include field-specific examples if available
     if hasattr(model, "Config") and hasattr(model.Config,

@@ -833,7 +837,7 @@ def generate_text_documentation(
         str: Generated text documentation.
     """
     documentation = ""
-    pyd_models = [(model, True) for model in pydantic_models]
+    pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models]
     for model, add_prefix in pyd_models:
         if add_prefix:
             documentation += f"{model_prefix}: {model.__name__}\n"

@@ -1164,7 +1168,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
         dynamic_fields[param.name] = (
             param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
     # Creating the dynamic model
-    dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)  # type: ignore[call-overload]
+    dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)

     for name, param_doc in param_docs:
         dynamic_model.model_fields[name].description = param_doc.description

@@ -1228,9 +1232,6 @@ def map_grammar_names_to_pydantic_model_class(pydantic_model_list):
     return output


-from enum import Enum
-
-
 def json_schema_to_python_types(schema):
     type_map = {
         "any": Any,

@@ -1275,7 +1276,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
             if items != {}:
                 array = {"properties": items}
                 array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
-                fields[field_name] = (List[array_type], ...)  # type: ignore[valid-type]
+                fields[field_name] = (List[array_type], ...)
             else:
                 fields[field_name] = (list, ...)
         elif field_type == "object":

@@ -1285,7 +1286,8 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
             required = field_data.get("enum", [])
             for key, field in fields.items():
                 if key not in required:
-                    fields[key] = (Optional[fields[key][0]], ...)
+                    optional_type = fields[key][0]
+                    fields[key] = (Optional[optional_type], ...)
         else:
             field_type = json_schema_to_python_types(field_type)
             fields[field_name] = (field_type, ...)

@@ -1305,6 +1307,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
     required = dictionary.get("required", [])
     for key, field in fields.items():
         if key not in required:
-            fields[key] = (Optional[fields[key][0]], ...)
+            optional_type = fields[key][0]
+            fields[key] = (Optional[optional_type], ...)
     custom_model = create_model(model_name, **fields)
     return custom_model

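The `pyd_models` annotation is spelled out as `list[tuple[type[BaseModel], bool]]`, presumably because Pyright otherwise infers the element type from the initial comprehension alone, which can be too narrow for the tuples appended later while the list is being walked. A reduced sketch of the idea, assuming pydantic is installed; the models and the nesting rule are hypothetical:

```python
from __future__ import annotations

from pydantic import BaseModel


class User(BaseModel):
    name: str


class Group(BaseModel):
    users: list[User]


def collect(models: list[type[BaseModel]]) -> list[str]:
    # without the annotation the checker may pin the element type to the
    # initial contents, rejecting the append inside the loop below
    pyd_models: list[tuple[type[BaseModel], bool]] = [(m, True) for m in models]
    names: list[str] = []
    for model, add_prefix in pyd_models:
        names.append(f"Model: {model.__name__}" if add_prefix else model.__name__)
        if model is Group:
            # nested models are walked too, but documented without the prefix
            pyd_models.append((User, False))
    return names


print(collect([User, Group]))  # ['Model: User', 'Model: Group', 'User']
```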
@@ -1,6 +1,7 @@
 # Function calling example using pydantic models.
+from __future__ import annotations

 import datetime
 import importlib
 import json
 from enum import Enum
 from typing import Optional, Union

@@ -215,9 +216,9 @@ for call in json_data:
     if call["function"] == "Calculator":
         print(Calculator(**call["params"]).run())
     elif call["function"] == "get_current_datetime":
-        print(current_datetime_model(**call["params"]).run())
+        print(current_datetime_model(**call["params"]).run())  # pyright: ignore[reportAttributeAccessIssue]
     elif call["function"] == "get_current_weather":
-        print(current_weather_tool_model(**call["params"]).run())
+        print(current_weather_tool_model(**call["params"]).run())  # pyright: ignore[reportAttributeAccessIssue]
 # Should output something like this:
 # 2024-01-14 13:36:06
 # {"location": "London", "temperature": "42", "unit": "celsius"}

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import json
 import os

@@ -59,10 +61,11 @@ def main(args_in: list[str] | None = None) -> None:
         sys.exit(1)

     # start the benchmark
+    iterations = 0
+    data = {}
     try:
         start_benchmark(args)

-        iterations = 0
         with open("results.github.env", 'w') as github_env:
             # parse output
             with open('k6-results.json', 'r') as bench_results:

@@ -129,7 +132,7 @@ def main(args_in: list[str] | None = None) -> None:
             timestamps, metric_values = zip(*values)
             metric_values = [float(value) for value in metric_values]
             prometheus_metrics[metric] = metric_values
-            timestamps_dt = [datetime.fromtimestamp(int(ts)) for ts in timestamps]
+            timestamps_dt = [str(datetime.fromtimestamp(int(ts))) for ts in timestamps]
             plt.figure(figsize=(16, 10), dpi=80)
             plt.plot(timestamps_dt, metric_values, label=metric)
             plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7)

@@ -156,7 +159,7 @@ def main(args_in: list[str] | None = None) -> None:
             plt.close()

             # Mermaid format in case images upload failed
-            with (open(f"{metric}.mermaid", 'w') as mermaid_f):
+            with open(f"{metric}.mermaid", 'w') as mermaid_f:
                 mermaid = (
                     f"""---
 config:

@@ -278,7 +281,7 @@ def start_server_background(args):
     }
     server_process = subprocess.Popen(
         args,
-        **pkwargs)
+        **pkwargs)  # pyright: ignore[reportArgumentType, reportCallIssue]

     def server_log(in_stream, out_stream):
         for line in iter(in_stream.readline, b''):

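Hoisting `iterations = 0` and `data = {}` above the `try:` block means both names are bound even when the benchmark fails to start, so the reporting code after the handler type-checks (and works) on every path. A compact sketch of the pattern with a stand-in benchmark function:

```python
from __future__ import annotations

import random


def start_benchmark() -> None:
    # stand-in for the real k6 invocation; fails half of the time
    if random.random() < 0.5:
        raise RuntimeError("benchmark failed to start")


def main() -> int:
    # bound before the try block, so they exist on every path below
    iterations = 0
    data: dict[str, float] = {}

    try:
        start_benchmark()
        iterations = 42
        data["prompt_tokens_avg"] = 123.4
    except RuntimeError as err:
        print(f"bench failed: {err}")

    # safe to read even if the benchmark never ran
    print(f"iterations={iterations} metrics={len(data)}")
    return 0 if iterations > 0 else 1


raise SystemExit(main())
```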
@@ -1,5 +1,4 @@
 import asyncio
-import collections
 import json
 import os
 import re

@@ -8,16 +7,20 @@ import subprocess
 import sys
 import threading
 import time
+from collections.abc import Sequence
 from contextlib import closing
 from re import RegexFlag
+from typing import cast

 import aiohttp
 import numpy as np
 import openai
-from behave import step
+from openai.types.chat import ChatCompletionChunk
+from behave import step  # pyright: ignore[reportAttributeAccessIssue]
 from behave.api.async_step import async_run_until_complete
 from prometheus_client import parser

+# pyright: reportRedeclaration=false
+
 @step("a server listening on {server_fqdn}:{server_port}")
 def step_server_config(context, server_fqdn, server_port):

@@ -777,8 +780,8 @@ def step_assert_metric_value(context, metric_name, metric_value):
 def step_available_models(context):
     # openai client always expects an api_key
     openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
-    openai.api_base = f'{context.base_url}/v1'
-    context.models = openai.Model.list().data
+    openai.base_url = f'{context.base_url}/v1'
+    context.models = openai.models.list().data


 @step('{n_model:d} models are supported')

@@ -810,6 +813,7 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
     print(f"starting {context.n_prompts} concurrent completion requests...")
     assert context.n_prompts > 0
     seeds = await completions_seed(context)
+    assert seeds is not None
     for prompt_no in range(context.n_prompts):
         shifted_args = [context.prompts.pop(), seeds[prompt_no], *args]
         context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))

@@ -989,32 +993,35 @@ async def oai_chat_completions(user_prompt,
     else:
         try:
             openai.api_key = user_api_key
-            openai.api_base = f'{base_url}{base_path}'
-            chat_completion = openai.Completion.create(
+            openai.base_url = f'{base_url}{base_path}'
+            assert model is not None
+            chat_completion = openai.chat.completions.create(
                 messages=payload['messages'],
                 model=model,
                 max_tokens=n_predict,
                 stream=enable_streaming,
-                response_format=payload.get('response_format'),
+                response_format=payload.get('response_format') or openai.NOT_GIVEN,
                 seed=seed,
                 temperature=payload['temperature']
             )
-        except openai.error.AuthenticationError as e:
+        except openai.AuthenticationError as e:
             if expect_api_error is not None and expect_api_error:
                 return 401
             else:
                 assert False, f'error raised: {e}'

         if enable_streaming:
+            chat_completion = cast(openai.Stream[ChatCompletionChunk], chat_completion)
             for chunk in chat_completion:
                 assert len(chunk.choices) == 1
                 delta = chunk.choices[0].delta
-                if 'content' in delta:
-                    completion_response['content'] += delta['content']
+                if delta.content is not None:
+                    completion_response['content'] += delta.content
                     completion_response['timings']['predicted_n'] += 1
                 completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
         else:
             assert len(chat_completion.choices) == 1
+            assert chat_completion.usage is not None
             completion_response = {
                 'content': chat_completion.choices[0].message.content,
                 'timings': {

@@ -1063,7 +1070,7 @@ async def request_oai_embeddings(input, seed,
             response_json = await response.json()
             assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
             assert response_json['object'] == 'list'
-            if isinstance(input, collections.abc.Sequence):
+            if isinstance(input, Sequence):
                 embeddings = []
                 for an_oai_embeddings in response_json['data']:
                     embeddings.append(an_oai_embeddings['embedding'])

@@ -1072,19 +1079,14 @@ async def request_oai_embeddings(input, seed,
             return embeddings
     else:
         openai.api_key = user_api_key
-        openai.api_base = f'{base_url}/v1'
-        oai_embeddings = openai.Embedding.create(
+        openai.base_url = f'{base_url}/v1'
+        assert model is not None
+        oai_embeddings = openai.embeddings.create(
             model=model,
             input=input,
         )

-        if isinstance(input, collections.abc.Sequence):
-            embeddings = []
-            for an_oai_embeddings in oai_embeddings.data:
-                embeddings.append(an_oai_embeddings.embedding)
-        else:
-            embeddings = [oai_embeddings.data.embedding]
-        return embeddings
+        return oai_embeddings.data


 def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):

@@ -1122,7 +1124,7 @@ def assert_all_predictions_equal(completion_responses):
             if i == j:
                 continue
             content_j = response_j['content']
-            assert content_i == content_j, "contents not equal"
+            assert content_i == content_j, "contents not equal"


 def assert_all_predictions_different(completion_responses):

@@ -1136,7 +1138,7 @@ def assert_all_predictions_different(completion_responses):
             if i == j:
                 continue
             content_j = response_j['content']
-            assert content_i != content_j, "contents not different"
+            assert content_i != content_j, "contents not different"


 def assert_all_token_probabilities_equal(completion_responses):

@@ -1153,7 +1155,7 @@ def assert_all_token_probabilities_equal(completion_responses):
             if i == j:
                 continue
             probs_j = response_j['completion_probabilities'][pos]['probs']
-            assert probs_i == probs_j, "contents not equal"
+            assert probs_i == probs_j, "contents not equal"


 async def gather_tasks_results(context):

@@ -1343,7 +1345,7 @@ def start_server_background(context):
     }
     context.server_process = subprocess.Popen(
         [str(arg) for arg in [context.server_path, *server_args]],
-        **pkwargs)
+        **pkwargs)  # pyright: ignore[reportArgumentType, reportCallIssue]

     def server_log(in_stream, out_stream):
         for line in iter(in_stream.readline, b''):

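The server test steps move from the openai 0.x module-level API to the 1.x client surface: `openai.api_base` becomes `openai.base_url`, `openai.Completion.create` becomes `openai.chat.completions.create`, errors live directly on the package (`openai.AuthenticationError`), and omitted options are passed as `openai.NOT_GIVEN`. A hedged sketch of a non-streaming call against a locally running server; the URL and model name are placeholders, not values from the test suite:

```python
import openai

# placeholders for a locally running server; adjust to your setup
openai.api_key = "nope"  # many local servers do not check the key
openai.base_url = "http://localhost:8080/v1/"


def ask(prompt: str, model: str = "gpt-3.5-turbo") -> str:
    try:
        chat_completion = openai.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model,
            max_tokens=64,
            # omitted options are NOT_GIVEN rather than None in the 1.x client
            response_format=openai.NOT_GIVEN,
            temperature=0.8,
        )
    except openai.AuthenticationError as err:
        raise RuntimeError(f"server rejected the API key: {err}") from err

    content = chat_completion.choices[0].message.content
    return content if content is not None else ""


if __name__ == "__main__":
    print(ask("Say hello in one short sentence."))
```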
@@ -1,6 +1,6 @@
 aiohttp~=3.9.3
 behave~=1.2.6
 huggingface_hub~=0.20.3
-numpy~=1.24.4
-openai~=0.25.0
+numpy~=1.26.4
+openai~=1.30.3
 prometheus-client~=0.20.0

@@ -1,13 +1,15 @@
 import asyncio
+import asyncio.threads
 import requests
 import numpy as np


 n = 8

+result = []

 async def requests_post_async(*args, **kwargs):
-    return await asyncio.to_thread(requests.post, *args, **kwargs)
+    return await asyncio.threads.to_thread(requests.post, *args, **kwargs)

 async def main():
     model_url = "http://127.0.0.1:6900"

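The script wraps the blocking `requests.post` in a worker thread so many requests can be awaited concurrently; `asyncio.to_thread` is defined in the `asyncio.threads` submodule and re-exported from the package, so `asyncio.threads.to_thread` above is the same callable reached through its defining module. A small sketch of the fan-out pattern against a placeholder endpoint (the URL and payload are assumptions, not the script's actual values):

```python
import asyncio

import requests

# placeholder endpoint for a locally running completion server
MODEL_URL = "http://127.0.0.1:6900/completion"


async def requests_post_async(*args, **kwargs):
    # run the blocking call in a worker thread so the event loop stays free
    return await asyncio.to_thread(requests.post, *args, **kwargs)


async def main() -> None:
    n = 8
    payload = {"prompt": "Hello", "n_predict": 16}
    responses = await asyncio.gather(
        *[requests_post_async(MODEL_URL, json=payload) for _ in range(n)]
    )
    for response in responses:
        print(response.status_code)


if __name__ == "__main__":
    asyncio.run(main())
```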
@@ -66,7 +66,7 @@ class Tensor:
         if len(self.ne) == 0:
             self.nbytes = 0
         else:
-            self.nbytes = int(np.product(self.ne)) * 4
+            self.nbytes = int(np.prod(self.ne)) * 4
     else:
         raise ValueError(f"Unhandled data type '{self.dtype}'")