Compare commits
10 commits
master
...
compilade/
Author | SHA1 | Date | |
---|---|---|---|
|
86ccd30983 | ||
|
6ec70c93be | ||
|
6f215f1f0d | ||
|
0caf60a79e | ||
|
872aecbf30 | ||
|
60c39aca43 | ||
|
959c057bd9 | ||
|
71b50a148c | ||
|
fbf4a85868 | ||
|
e29fd9634c |
33 changed files with 297 additions and 173 deletions
|
@ -89,6 +89,22 @@ let
|
|||
ps.tiktoken
|
||||
ps.torchWithoutCuda
|
||||
ps.transformers
|
||||
|
||||
# server bench
|
||||
ps.matplotlib
|
||||
|
||||
# server tests
|
||||
ps.openai
|
||||
ps.behave
|
||||
ps.prometheus-client
|
||||
|
||||
# for examples/pydantic-models-to-grammar-examples.py
|
||||
ps.docstring-parser
|
||||
ps.pydantic
|
||||
|
||||
# for scripts/compare-llama-bench.py
|
||||
ps.gitpython
|
||||
ps.tabulate
|
||||
]
|
||||
);
|
||||
|
||||
|
|
38
.github/workflows/python-type-check.yml
vendored
Normal file
38
.github/workflows/python-type-check.yml
vendored
Normal file
|
@ -0,0 +1,38 @@
|
|||
name: Python Type-Check
|
||||
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- '.github/workflows/python-type-check.yml'
|
||||
- '**.py'
|
||||
- '**/requirements*.txt'
|
||||
pull_request:
|
||||
paths:
|
||||
- '.github/workflows/python-type-check.yml'
|
||||
- '**.py'
|
||||
- '**/requirements*.txt'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
python-type-check:
|
||||
runs-on: ubuntu-latest
|
||||
name: pyright type-check
|
||||
steps:
|
||||
- name: Check out source repository
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Python environment
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Python dependencies
|
||||
# TODO: use a venv
|
||||
run: pip install -r requirements/requirements-all.txt
|
||||
- name: Type-check with Pyright
|
||||
uses: jakebailey/pyright-action@v2
|
||||
with:
|
||||
version: 1.1.370
|
||||
level: warning
|
||||
warnings: true
|
|
@ -265,7 +265,7 @@ class Model:
|
|||
break
|
||||
|
||||
for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
|
||||
data: np.ndarray = data # type hint
|
||||
data: np.ndarray # type hint
|
||||
n_dims = len(data.shape)
|
||||
data_dtype = data.dtype
|
||||
data_qtype: gguf.GGMLQuantizationType | None = None
|
||||
|
@ -599,10 +599,6 @@ class Model:
|
|||
|
||||
tokenizer_path = self.dir_model / 'tokenizer.model'
|
||||
|
||||
tokens: list[bytes] = []
|
||||
scores: list[float] = []
|
||||
toktypes: list[int] = []
|
||||
|
||||
if not tokenizer_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {tokenizer_path}")
|
||||
|
||||
|
@ -2120,7 +2116,7 @@ class InternLM2Model(Model):
|
|||
logger.error(f'Error: Missing {tokenizer_path}')
|
||||
sys.exit(1)
|
||||
|
||||
sentencepiece_model = model.ModelProto()
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
|
||||
|
||||
|
@ -2972,7 +2968,7 @@ class T5Model(Model):
|
|||
if not tokenizer_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {tokenizer_path}")
|
||||
|
||||
sentencepiece_model = model.ModelProto()
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
|
||||
# some models like Pile-T5 family use BPE tokenizer instead of Unigram
|
||||
|
@ -3152,7 +3148,7 @@ class JaisModel(Model):
|
|||
# but Jais's PyTorch model simply precalculates the slope values and places them
|
||||
# in relative_pes.slopes
|
||||
n_head_closest_log2 = 2 ** math.floor(math.log2(self.hparams["n_head"]))
|
||||
first_val = float(data_torch._data[0])
|
||||
first_val = float(data_torch[0].item())
|
||||
self.max_alibi_bias = -round(math.log2(first_val) * n_head_closest_log2)
|
||||
|
||||
return tensors
|
||||
|
@ -3186,7 +3182,7 @@ class ChatGLMModel(Model):
|
|||
def set_vocab_chatglm3(self):
|
||||
dir_model = self.dir_model
|
||||
hparams = self.hparams
|
||||
tokens: list[bytearray] = []
|
||||
tokens: list[bytes] = []
|
||||
toktypes: list[int] = []
|
||||
scores: list[float] = []
|
||||
|
||||
|
@ -3335,7 +3331,7 @@ class ChatGLMModel(Model):
|
|||
special_vocab.add_to_gguf(self.gguf_writer)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_name(self.hparams.get("_name_or_path").split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b
|
||||
self.gguf_writer.add_name(self.hparams["_name_or_path"].split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b
|
||||
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
|
||||
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
|
||||
n_head_kv = self.hparams.get("multi_query_group_num", n_head)
|
||||
|
|
|
@ -354,7 +354,8 @@ class GGMLToGGUF:
|
|||
|
||||
|
||||
def handle_metadata(cfg, hp):
|
||||
import convert
|
||||
import examples.convert_legacy_llama as convert
|
||||
|
||||
assert cfg.model_metadata_dir.is_dir(), 'Metadata dir is not a directory'
|
||||
hf_config_path = cfg.model_metadata_dir / "config.json"
|
||||
orig_config_path = cfg.model_metadata_dir / "params.json"
|
||||
|
|
|
@ -353,7 +353,7 @@ class Metadata:
|
|||
version: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
licence: Optional[str] = None
|
||||
license: Optional[str] = None
|
||||
source_url: Optional[str] = None
|
||||
source_hf_repo: Optional[str] = None
|
||||
|
||||
|
@ -492,12 +492,13 @@ class LazyTensor:
|
|||
|
||||
LazyModel: TypeAlias = 'dict[str, LazyTensor]'
|
||||
|
||||
ModelFormat: TypeAlias = Literal['ggml', 'torch', 'safetensors', 'none']
|
||||
|
||||
@dataclass
|
||||
class ModelPlus:
|
||||
model: LazyModel
|
||||
paths: list[Path] # Where this was read from.
|
||||
format: Literal['ggml', 'torch', 'safetensors', 'none']
|
||||
format: ModelFormat
|
||||
vocab: BaseVocab | None # For GGML models (which have vocab built in), the vocab.
|
||||
|
||||
|
||||
|
@ -536,7 +537,7 @@ def merge_sharded(models: list[LazyModel]) -> LazyModel:
|
|||
|
||||
|
||||
def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
|
||||
formats = set(mp.format for mp in models_plus)
|
||||
formats: set[ModelFormat] = set(mp.format for mp in models_plus)
|
||||
assert len(formats) == 1, "different formats?"
|
||||
format = formats.pop()
|
||||
paths = [path for mp in models_plus for path in mp.paths]
|
||||
|
@ -555,7 +556,7 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
|
|||
else:
|
||||
model = merge_sharded([mp.model for mp in models_plus])
|
||||
|
||||
return ModelPlus(model, paths, format, vocab) # pytype: disable=wrong-arg-types
|
||||
return ModelPlus(model, paths, format, vocab)
|
||||
|
||||
|
||||
def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor:
|
||||
|
@ -805,7 +806,7 @@ class OutputFile:
|
|||
def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
|
||||
self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
|
||||
|
||||
def add_meta_model(self, params: Params, metadata: Metadata) -> None:
|
||||
def add_meta_model(self, params: Params, metadata: Metadata | None) -> None:
|
||||
# Metadata About The Model And Its Provenence
|
||||
name = "LLaMA"
|
||||
if metadata is not None and metadata.name is not None:
|
||||
|
@ -827,8 +828,8 @@ class OutputFile:
|
|||
self.gguf.add_url(metadata.url)
|
||||
if metadata.description is not None:
|
||||
self.gguf.add_description(metadata.description)
|
||||
if metadata.licence is not None:
|
||||
self.gguf.add_licence(metadata.licence)
|
||||
if metadata.license is not None:
|
||||
self.gguf.add_licence(metadata.license)
|
||||
if metadata.source_url is not None:
|
||||
self.gguf.add_source_url(metadata.source_url)
|
||||
if metadata.source_hf_repo is not None:
|
||||
|
@ -943,7 +944,7 @@ class OutputFile:
|
|||
@staticmethod
|
||||
def write_vocab_only(
|
||||
fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab,
|
||||
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata = None,
|
||||
endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: Metadata | None = None,
|
||||
) -> None:
|
||||
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
|
||||
|
||||
|
@ -977,7 +978,7 @@ class OutputFile:
|
|||
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
|
||||
concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
|
||||
pad_vocab: bool = False,
|
||||
metadata: Metadata = None,
|
||||
metadata: Metadata | None = None,
|
||||
) -> None:
|
||||
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
|
||||
|
||||
|
@ -1396,6 +1397,8 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab:
|
||||
vocab = model_plus.vocab
|
||||
|
||||
assert params is not None
|
||||
|
||||
logger.info(f"Vocab info: {vocab}")
|
||||
logger.info(f"Special vocab info: {special_vocab}")
|
||||
model = model_plus.model
|
||||
|
|
|
@ -74,7 +74,7 @@ class Tensor:
|
|||
if len(self.ne) == 0:
|
||||
self.nbytes = 0
|
||||
else:
|
||||
self.nbytes = int(np.product(self.ne)) * 4
|
||||
self.nbytes = int(np.prod(self.ne)) * 4
|
||||
else:
|
||||
raise ValueError(f"Unhandled data type '{self.dtype}'")
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
#! pip install pydantic
|
||||
#! python json_schema_pydantic_example.py
|
||||
|
||||
from pydantic import BaseModel, Extra, TypeAdapter
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from annotated_types import MinLen
|
||||
from typing import Annotated, List, Optional
|
||||
import json, requests
|
||||
|
@ -17,6 +17,9 @@ if True:
|
|||
|
||||
The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below)
|
||||
'''
|
||||
response_format = None
|
||||
type_adapter = None
|
||||
|
||||
if response_model:
|
||||
type_adapter = TypeAdapter(response_model)
|
||||
schema = type_adapter.json_schema()
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
|
@ -188,7 +190,7 @@ def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], ou
|
|||
raise RuntimeError("At least one of min_value or max_value must be set")
|
||||
|
||||
class BuiltinRule:
|
||||
def __init__(self, content: str, deps: list = None):
|
||||
def __init__(self, content: str, deps: list | None = None):
|
||||
self.content = content
|
||||
self.deps = deps or []
|
||||
|
||||
|
@ -248,7 +250,7 @@ class SchemaConverter:
|
|||
|
||||
def _format_literal(self, literal):
|
||||
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
||||
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
|
||||
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal
|
||||
)
|
||||
return f'"{escaped}"'
|
||||
|
||||
|
@ -403,11 +405,11 @@ class SchemaConverter:
|
|||
i = 0
|
||||
length = len(pattern)
|
||||
|
||||
def to_rule(s: Tuple[str, bool]) -> str:
|
||||
def to_rule(s: tuple[str, bool]) -> str:
|
||||
(txt, is_literal) = s
|
||||
return "\"" + txt + "\"" if is_literal else txt
|
||||
|
||||
def transform() -> Tuple[str, bool]:
|
||||
def transform() -> tuple[str, bool]:
|
||||
'''
|
||||
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
|
||||
'''
|
||||
|
@ -420,7 +422,7 @@ class SchemaConverter:
|
|||
# We only need a flat structure here to apply repetition operators to the last item, and
|
||||
# to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
|
||||
# (GBNF's syntax is luckily very close to regular expressions!)
|
||||
seq: list[Tuple[str, bool]] = []
|
||||
seq: list[tuple[str, bool]] = []
|
||||
|
||||
def get_dot():
|
||||
if self._dotall:
|
||||
|
|
|
@ -185,6 +185,8 @@ else:
|
|||
fout.add_description("two-tower CLIP model")
|
||||
|
||||
if has_text_encoder:
|
||||
assert t_hparams is not None
|
||||
assert tokens is not None
|
||||
# text_model hparams
|
||||
fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"])
|
||||
fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"])
|
||||
|
@ -259,8 +261,8 @@ if has_vision_encoder:
|
|||
|
||||
|
||||
if processor is not None:
|
||||
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean
|
||||
image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std
|
||||
image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue]
|
||||
image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std # pyright: ignore[reportAttributeAccessIssue]
|
||||
else:
|
||||
image_mean = args.image_mean if args.image_mean is not None else default_image_mean
|
||||
image_std = args.image_std if args.image_std is not None else default_image_std
|
||||
|
@ -272,7 +274,7 @@ fout.add_bool("clip.use_gelu", use_gelu)
|
|||
|
||||
|
||||
if has_llava_projector:
|
||||
model.vision_model.encoder.layers.pop(-1)
|
||||
model.vision_model.encoder.layers.pop(-1) # pyright: ignore[reportAttributeAccessIssue]
|
||||
projector = torch.load(args.llava_projector)
|
||||
for name, data in projector.items():
|
||||
name = get_tensor_name(name)
|
||||
|
@ -286,7 +288,7 @@ if has_llava_projector:
|
|||
|
||||
print("Projector tensors added\n")
|
||||
|
||||
state_dict = model.state_dict()
|
||||
state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue]
|
||||
for name, data in state_dict.items():
|
||||
if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector):
|
||||
# we don't need this
|
||||
|
|
|
@ -2,7 +2,9 @@ import argparse
|
|||
import glob
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load as safe_load, save as safe_save, safe_open, save_file
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import save_file
|
||||
from typing import Any, ContextManager, cast
|
||||
|
||||
# Function to determine if file is a SafeTensor file
|
||||
def is_safetensor_file(file_path):
|
||||
|
@ -13,7 +15,7 @@ def is_safetensor_file(file_path):
|
|||
def load_model(file_path):
|
||||
if is_safetensor_file(file_path):
|
||||
tensors = {}
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f:
|
||||
for key in f.keys():
|
||||
tensors[key] = f.get_tensor(key).clone()
|
||||
# output shape
|
||||
|
@ -134,7 +136,7 @@ if len(mm_tensors) == 0:
|
|||
if last_checkpoint is not None:
|
||||
for k, v in last_checkpoint.items():
|
||||
print(k)
|
||||
print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint)} tensors.")
|
||||
print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.")
|
||||
print("No tensors found. Is this a LLaVA model?")
|
||||
exit()
|
||||
|
||||
|
@ -143,8 +145,10 @@ print(f"Found additional {len(first_mm_tensors)} tensors to extract.")
|
|||
# projector = {name: checkpoint.[name].float() for name in mm_tensors}
|
||||
projector = {}
|
||||
for name in mm_tensors:
|
||||
assert last_checkpoint is not None
|
||||
projector[name] = last_checkpoint[name].float()
|
||||
for name in first_mm_tensors:
|
||||
assert first_checkpoint is not None
|
||||
projector[name] = first_checkpoint[name].float()
|
||||
|
||||
if len(projector) > 0:
|
||||
|
|
|
@ -6,10 +6,10 @@ import re
|
|||
from copy import copy
|
||||
from enum import Enum
|
||||
from inspect import getdoc, isclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin
|
||||
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import GenericAlias
|
||||
|
@ -17,6 +17,9 @@ else:
|
|||
# python 3.8 compat
|
||||
from typing import _GenericAlias as GenericAlias
|
||||
|
||||
# TODO: fix this
|
||||
# pyright: reportAttributeAccessIssue=information
|
||||
|
||||
|
||||
class PydanticDataType(Enum):
|
||||
"""
|
||||
|
@ -234,8 +237,9 @@ def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None
|
|||
|
||||
# Define the integer part rule
|
||||
integer_part_rule = (
|
||||
"integer-part" + (f"-max{max_digit}" if max_digit is not None else "") + (
|
||||
f"-min{min_digit}" if min_digit is not None else "")
|
||||
"integer-part"
|
||||
+ (f"-max{max_digit}" if max_digit is not None else "")
|
||||
+ (f"-min{min_digit}" if min_digit is not None else "")
|
||||
)
|
||||
|
||||
# Define the fractional part rule based on precision constraints
|
||||
|
@ -458,7 +462,7 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas
|
|||
if not issubclass(model, BaseModel):
|
||||
# For non-Pydantic classes, generate model_fields from __annotations__ or __init__
|
||||
if hasattr(model, "__annotations__") and model.__annotations__:
|
||||
model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()}
|
||||
model_fields = {name: (typ, ...) for name, typ in model.__annotations__.items()} # pyright: ignore[reportGeneralTypeIssues]
|
||||
else:
|
||||
init_signature = inspect.signature(model.__init__)
|
||||
parameters = init_signature.parameters
|
||||
|
@ -680,7 +684,7 @@ def generate_markdown_documentation(
|
|||
str: Generated text documentation.
|
||||
"""
|
||||
documentation = ""
|
||||
pyd_models = [(model, True) for model in pydantic_models]
|
||||
pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models]
|
||||
for model, add_prefix in pyd_models:
|
||||
if add_prefix:
|
||||
documentation += f"{model_prefix}: {model.__name__}\n"
|
||||
|
@ -700,7 +704,7 @@ def generate_markdown_documentation(
|
|||
# Indenting the fields section
|
||||
documentation += f" {fields_prefix}:\n"
|
||||
else:
|
||||
documentation += f" Fields:\n"
|
||||
documentation += f" Fields:\n" # noqa: F541
|
||||
if isclass(model) and issubclass(model, BaseModel):
|
||||
for name, field_type in model.__annotations__.items():
|
||||
# if name == "markdown_code_block":
|
||||
|
@ -778,7 +782,7 @@ def generate_field_markdown(
|
|||
return field_text
|
||||
|
||||
if field_description != "":
|
||||
field_text += f" Description: " + field_description + "\n"
|
||||
field_text += f" Description: {field_description}\n"
|
||||
|
||||
# Check for and include field-specific examples if available
|
||||
if hasattr(model, "Config") and hasattr(model.Config,
|
||||
|
@ -833,7 +837,7 @@ def generate_text_documentation(
|
|||
str: Generated text documentation.
|
||||
"""
|
||||
documentation = ""
|
||||
pyd_models = [(model, True) for model in pydantic_models]
|
||||
pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models]
|
||||
for model, add_prefix in pyd_models:
|
||||
if add_prefix:
|
||||
documentation += f"{model_prefix}: {model.__name__}\n"
|
||||
|
@ -1164,7 +1168,7 @@ def create_dynamic_model_from_function(func: Callable[..., Any]):
|
|||
dynamic_fields[param.name] = (
|
||||
param.annotation if param.annotation != inspect.Parameter.empty else str, default_value)
|
||||
# Creating the dynamic model
|
||||
dynamic_model = create_model(f"{func.__name__}", **dynamic_fields) # type: ignore[call-overload]
|
||||
dynamic_model = create_model(f"{func.__name__}", **dynamic_fields)
|
||||
|
||||
for name, param_doc in param_docs:
|
||||
dynamic_model.model_fields[name].description = param_doc.description
|
||||
|
@ -1228,9 +1232,6 @@ def map_grammar_names_to_pydantic_model_class(pydantic_model_list):
|
|||
return output
|
||||
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
def json_schema_to_python_types(schema):
|
||||
type_map = {
|
||||
"any": Any,
|
||||
|
@ -1275,7 +1276,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
|
|||
if items != {}:
|
||||
array = {"properties": items}
|
||||
array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items")
|
||||
fields[field_name] = (List[array_type], ...) # type: ignore[valid-type]
|
||||
fields[field_name] = (List[array_type], ...)
|
||||
else:
|
||||
fields[field_name] = (list, ...)
|
||||
elif field_type == "object":
|
||||
|
@ -1285,7 +1286,8 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
|
|||
required = field_data.get("enum", [])
|
||||
for key, field in fields.items():
|
||||
if key not in required:
|
||||
fields[key] = (Optional[fields[key][0]], ...)
|
||||
optional_type = fields[key][0]
|
||||
fields[key] = (Optional[optional_type], ...)
|
||||
else:
|
||||
field_type = json_schema_to_python_types(field_type)
|
||||
fields[field_name] = (field_type, ...)
|
||||
|
@ -1305,6 +1307,7 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name:
|
|||
required = dictionary.get("required", [])
|
||||
for key, field in fields.items():
|
||||
if key not in required:
|
||||
fields[key] = (Optional[fields[key][0]], ...)
|
||||
optional_type = fields[key][0]
|
||||
fields[key] = (Optional[optional_type], ...)
|
||||
custom_model = create_model(model_name, **fields)
|
||||
return custom_model
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Function calling example using pydantic models.
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
@ -215,9 +216,9 @@ for call in json_data:
|
|||
if call["function"] == "Calculator":
|
||||
print(Calculator(**call["params"]).run())
|
||||
elif call["function"] == "get_current_datetime":
|
||||
print(current_datetime_model(**call["params"]).run())
|
||||
print(current_datetime_model(**call["params"]).run()) # pyright: ignore[reportAttributeAccessIssue]
|
||||
elif call["function"] == "get_current_weather":
|
||||
print(current_weather_tool_model(**call["params"]).run())
|
||||
print(current_weather_tool_model(**call["params"]).run()) # pyright: ignore[reportAttributeAccessIssue]
|
||||
# Should output something like this:
|
||||
# 2024-01-14 13:36:06
|
||||
# {"location": "London", "temperature": "42", "unit": "celsius"}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
@ -59,10 +61,11 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
sys.exit(1)
|
||||
|
||||
# start the benchmark
|
||||
iterations = 0
|
||||
data = {}
|
||||
try:
|
||||
start_benchmark(args)
|
||||
|
||||
iterations = 0
|
||||
with open("results.github.env", 'w') as github_env:
|
||||
# parse output
|
||||
with open('k6-results.json', 'r') as bench_results:
|
||||
|
@ -129,7 +132,7 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
timestamps, metric_values = zip(*values)
|
||||
metric_values = [float(value) for value in metric_values]
|
||||
prometheus_metrics[metric] = metric_values
|
||||
timestamps_dt = [datetime.fromtimestamp(int(ts)) for ts in timestamps]
|
||||
timestamps_dt = [str(datetime.fromtimestamp(int(ts))) for ts in timestamps]
|
||||
plt.figure(figsize=(16, 10), dpi=80)
|
||||
plt.plot(timestamps_dt, metric_values, label=metric)
|
||||
plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7)
|
||||
|
@ -156,7 +159,7 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
plt.close()
|
||||
|
||||
# Mermaid format in case images upload failed
|
||||
with (open(f"{metric}.mermaid", 'w') as mermaid_f):
|
||||
with open(f"{metric}.mermaid", 'w') as mermaid_f:
|
||||
mermaid = (
|
||||
f"""---
|
||||
config:
|
||||
|
@ -278,7 +281,7 @@ def start_server_background(args):
|
|||
}
|
||||
server_process = subprocess.Popen(
|
||||
args,
|
||||
**pkwargs)
|
||||
**pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue]
|
||||
|
||||
def server_log(in_stream, out_stream):
|
||||
for line in iter(in_stream.readline, b''):
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import asyncio
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
@ -8,19 +7,23 @@ import subprocess
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from contextlib import closing
|
||||
from re import RegexFlag
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import openai
|
||||
from behave import step
|
||||
from openai.types.chat import ChatCompletionChunk
|
||||
from behave import step # pyright: ignore[reportAttributeAccessIssue]
|
||||
from behave.api.async_step import async_run_until_complete
|
||||
from prometheus_client import parser
|
||||
|
||||
# pyright: reportRedeclaration=false
|
||||
|
||||
@step("a server listening on {server_fqdn}:{server_port}")
|
||||
def step_server_config(context, server_fqdn, server_port):
|
||||
def step_server_config(context, server_fqdn: str, server_port: str):
|
||||
context.server_fqdn = server_fqdn
|
||||
context.server_port = int(server_port)
|
||||
context.n_threads = None
|
||||
|
@ -74,34 +77,34 @@ def step_server_config(context, server_fqdn, server_port):
|
|||
|
||||
|
||||
@step('a model file {hf_file} from HF repo {hf_repo}')
|
||||
def step_download_hf_model(context, hf_file, hf_repo):
|
||||
def step_download_hf_model(context, hf_file: str, hf_repo: str):
|
||||
context.model_hf_repo = hf_repo
|
||||
context.model_hf_file = hf_file
|
||||
context.model_file = os.path.basename(hf_file)
|
||||
|
||||
|
||||
@step('a model file {model_file}')
|
||||
def step_model_file(context, model_file):
|
||||
def step_model_file(context, model_file: str):
|
||||
context.model_file = model_file
|
||||
|
||||
|
||||
@step('a model url {model_url}')
|
||||
def step_model_url(context, model_url):
|
||||
def step_model_url(context, model_url: str):
|
||||
context.model_url = model_url
|
||||
|
||||
|
||||
@step('a model alias {model_alias}')
|
||||
def step_model_alias(context, model_alias):
|
||||
def step_model_alias(context, model_alias: str):
|
||||
context.model_alias = model_alias
|
||||
|
||||
|
||||
@step('{seed:d} as server seed')
|
||||
def step_seed(context, seed):
|
||||
def step_seed(context, seed: int):
|
||||
context.server_seed = seed
|
||||
|
||||
|
||||
@step('{ngl:d} GPU offloaded layers')
|
||||
def step_n_gpu_layer(context, ngl):
|
||||
def step_n_gpu_layer(context, ngl: int):
|
||||
if 'N_GPU_LAYERS' in os.environ:
|
||||
new_ngl = int(os.environ['N_GPU_LAYERS'])
|
||||
if context.debug:
|
||||
|
@ -111,37 +114,37 @@ def step_n_gpu_layer(context, ngl):
|
|||
|
||||
|
||||
@step('{n_threads:d} threads')
|
||||
def step_n_threads(context, n_threads):
|
||||
def step_n_threads(context, n_threads: int):
|
||||
context.n_thread = n_threads
|
||||
|
||||
|
||||
@step('{draft:d} as draft')
|
||||
def step_draft(context, draft):
|
||||
def step_draft(context, draft: int):
|
||||
context.draft = draft
|
||||
|
||||
|
||||
@step('{n_ctx:d} KV cache size')
|
||||
def step_n_ctx(context, n_ctx):
|
||||
def step_n_ctx(context, n_ctx: int):
|
||||
context.n_ctx = n_ctx
|
||||
|
||||
|
||||
@step('{n_slots:d} slots')
|
||||
def step_n_slots(context, n_slots):
|
||||
def step_n_slots(context, n_slots: int):
|
||||
context.n_slots = n_slots
|
||||
|
||||
|
||||
@step('{n_predict:d} server max tokens to predict')
|
||||
def step_server_n_predict(context, n_predict):
|
||||
def step_server_n_predict(context, n_predict: int):
|
||||
context.n_server_predict = n_predict
|
||||
|
||||
|
||||
@step('{slot_save_path} as slot save path')
|
||||
def step_slot_save_path(context, slot_save_path):
|
||||
def step_slot_save_path(context, slot_save_path: str):
|
||||
context.slot_save_path = slot_save_path
|
||||
|
||||
|
||||
@step('using slot id {id_slot:d}')
|
||||
def step_id_slot(context, id_slot):
|
||||
def step_id_slot(context, id_slot: int):
|
||||
context.id_slot = id_slot
|
||||
|
||||
|
||||
|
@ -191,7 +194,7 @@ def step_start_server(context):
|
|||
|
||||
@step("the server is {expecting_status}")
|
||||
@async_run_until_complete
|
||||
async def step_wait_for_the_server_to_be_started(context, expecting_status):
|
||||
async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str):
|
||||
match expecting_status:
|
||||
case 'healthy':
|
||||
await wait_for_health_status(context, context.base_url, 200, 'ok',
|
||||
|
@ -221,7 +224,7 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
|
|||
|
||||
@step('all slots are {expected_slot_status_string}')
|
||||
@async_run_until_complete
|
||||
async def step_all_slots_status(context, expected_slot_status_string):
|
||||
async def step_all_slots_status(context, expected_slot_status_string: Literal['idle', 'busy'] | str):
|
||||
match expected_slot_status_string:
|
||||
case 'idle':
|
||||
expected_slot_status = 0
|
||||
|
@ -237,7 +240,7 @@ async def step_all_slots_status(context, expected_slot_status_string):
|
|||
|
||||
@step('a completion request with {api_error} api error')
|
||||
@async_run_until_complete
|
||||
async def step_request_completion(context, api_error):
|
||||
async def step_request_completion(context, api_error: Literal['raised'] | str):
|
||||
expect_api_error = api_error == 'raised'
|
||||
seeds = await completions_seed(context, num_seeds=1)
|
||||
completion = await request_completion(context.prompts.pop(),
|
||||
|
@ -777,8 +780,8 @@ def step_assert_metric_value(context, metric_name, metric_value):
|
|||
def step_available_models(context):
|
||||
# openai client always expects an api_key
|
||||
openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
|
||||
openai.api_base = f'{context.base_url}/v1'
|
||||
context.models = openai.Model.list().data
|
||||
openai.base_url = f'{context.base_url}/v1/'
|
||||
context.models = openai.models.list().data
|
||||
|
||||
|
||||
@step('{n_model:d} models are supported')
|
||||
|
@ -789,7 +792,7 @@ def step_supported_models(context, n_model):
|
|||
|
||||
|
||||
@step('model {i_model:d} is {param} {preposition} {param_value}')
|
||||
def step_supported_models(context, i_model, param, preposition, param_value):
|
||||
def step_supported_models(context, i_model: int, param: Literal['identified', 'trained'] | str, preposition: str, param_value: str):
|
||||
assert i_model < len(context.models)
|
||||
model = context.models[i_model]
|
||||
|
||||
|
@ -798,7 +801,7 @@ def step_supported_models(context, i_model, param, preposition, param_value):
|
|||
case 'identified':
|
||||
value = model.id
|
||||
case 'trained':
|
||||
value = str(model.meta.n_ctx_train)
|
||||
value = str(model.meta["n_ctx_train"])
|
||||
case _:
|
||||
assert False, "param {param} not supported"
|
||||
assert param_value == value, f"model param {param} {value} != {param_value}"
|
||||
|
@ -810,6 +813,7 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
|
|||
print(f"starting {context.n_prompts} concurrent completion requests...")
|
||||
assert context.n_prompts > 0
|
||||
seeds = await completions_seed(context)
|
||||
assert seeds is not None
|
||||
for prompt_no in range(context.n_prompts):
|
||||
shifted_args = [context.prompts.pop(), seeds[prompt_no], *args]
|
||||
context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
|
||||
|
@ -861,7 +865,7 @@ async def request_completion(prompt,
|
|||
id_slot=None,
|
||||
expect_api_error=None,
|
||||
user_api_key=None,
|
||||
temperature=None):
|
||||
temperature=None) -> int | dict[str, Any]:
|
||||
if debug:
|
||||
print(f"Sending completion request: {prompt}")
|
||||
origin = "my.super.domain"
|
||||
|
@ -899,8 +903,8 @@ async def request_completion(prompt,
|
|||
async def oai_chat_completions(user_prompt,
|
||||
seed,
|
||||
system_prompt,
|
||||
base_url,
|
||||
base_path,
|
||||
base_url: str,
|
||||
base_path: str,
|
||||
async_client,
|
||||
debug=False,
|
||||
temperature=None,
|
||||
|
@ -909,7 +913,7 @@ async def oai_chat_completions(user_prompt,
|
|||
enable_streaming=None,
|
||||
response_format=None,
|
||||
user_api_key=None,
|
||||
expect_api_error=None):
|
||||
expect_api_error=None) -> int | dict[str, Any]:
|
||||
if debug:
|
||||
print(f"Sending OAI Chat completions request: {user_prompt}")
|
||||
# openai client always expects an api key
|
||||
|
@ -989,32 +993,35 @@ async def oai_chat_completions(user_prompt,
|
|||
else:
|
||||
try:
|
||||
openai.api_key = user_api_key
|
||||
openai.api_base = f'{base_url}{base_path}'
|
||||
chat_completion = openai.Completion.create(
|
||||
openai.base_url = f'{base_url}{base_path.removesuffix("chat")}'
|
||||
assert model is not None
|
||||
chat_completion = openai.chat.completions.create(
|
||||
messages=payload['messages'],
|
||||
model=model,
|
||||
max_tokens=n_predict,
|
||||
stream=enable_streaming,
|
||||
response_format=payload.get('response_format'),
|
||||
response_format=payload.get('response_format') or openai.NOT_GIVEN,
|
||||
seed=seed,
|
||||
temperature=payload['temperature']
|
||||
)
|
||||
except openai.error.AuthenticationError as e:
|
||||
except openai.AuthenticationError as e:
|
||||
if expect_api_error is not None and expect_api_error:
|
||||
return 401
|
||||
else:
|
||||
assert False, f'error raised: {e}'
|
||||
|
||||
if enable_streaming:
|
||||
chat_completion = cast(openai.Stream[ChatCompletionChunk], chat_completion)
|
||||
for chunk in chat_completion:
|
||||
assert len(chunk.choices) == 1
|
||||
delta = chunk.choices[0].delta
|
||||
if 'content' in delta:
|
||||
completion_response['content'] += delta['content']
|
||||
if delta.content is not None:
|
||||
completion_response['content'] += delta.content
|
||||
completion_response['timings']['predicted_n'] += 1
|
||||
completion_response['truncated'] = chunk.choices[0].finish_reason != 'stop'
|
||||
else:
|
||||
assert len(chat_completion.choices) == 1
|
||||
assert chat_completion.usage is not None
|
||||
completion_response = {
|
||||
'content': chat_completion.choices[0].message.content,
|
||||
'timings': {
|
||||
|
@ -1028,7 +1035,7 @@ async def oai_chat_completions(user_prompt,
|
|||
return completion_response
|
||||
|
||||
|
||||
async def request_embedding(content, seed, base_url=None):
|
||||
async def request_embedding(content, seed, base_url=None) -> list[list[float]]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f'{base_url}/embedding',
|
||||
json={
|
||||
|
@ -1041,7 +1048,7 @@ async def request_embedding(content, seed, base_url=None):
|
|||
|
||||
async def request_oai_embeddings(input, seed,
|
||||
base_url=None, user_api_key=None,
|
||||
model=None, async_client=False):
|
||||
model=None, async_client=False) -> list[list[float]]:
|
||||
# openai client always expects an api_key
|
||||
user_api_key = user_api_key if user_api_key is not None else 'nope'
|
||||
if async_client:
|
||||
|
@ -1063,7 +1070,7 @@ async def request_oai_embeddings(input, seed,
|
|||
response_json = await response.json()
|
||||
assert response_json['model'] == model, f"invalid model received: {response_json['model']}"
|
||||
assert response_json['object'] == 'list'
|
||||
if isinstance(input, collections.abc.Sequence):
|
||||
if isinstance(input, Sequence):
|
||||
embeddings = []
|
||||
for an_oai_embeddings in response_json['data']:
|
||||
embeddings.append(an_oai_embeddings['embedding'])
|
||||
|
@ -1072,19 +1079,14 @@ async def request_oai_embeddings(input, seed,
|
|||
return embeddings
|
||||
else:
|
||||
openai.api_key = user_api_key
|
||||
openai.api_base = f'{base_url}/v1'
|
||||
oai_embeddings = openai.Embedding.create(
|
||||
openai.base_url = f'{base_url}/v1/'
|
||||
assert model is not None
|
||||
oai_embeddings = openai.embeddings.create(
|
||||
model=model,
|
||||
input=input,
|
||||
)
|
||||
|
||||
if isinstance(input, collections.abc.Sequence):
|
||||
embeddings = []
|
||||
for an_oai_embeddings in oai_embeddings.data:
|
||||
embeddings.append(an_oai_embeddings.embedding)
|
||||
else:
|
||||
embeddings = [oai_embeddings.data.embedding]
|
||||
return embeddings
|
||||
return [e.embedding for e in oai_embeddings.data]
|
||||
|
||||
|
||||
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
|
||||
|
@ -1343,7 +1345,7 @@ def start_server_background(context):
|
|||
}
|
||||
context.server_process = subprocess.Popen(
|
||||
[str(arg) for arg in [context.server_path, *server_args]],
|
||||
**pkwargs)
|
||||
**pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue]
|
||||
|
||||
def server_log(in_stream, out_stream):
|
||||
for line in iter(in_stream.readline, b''):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
aiohttp~=3.9.3
|
||||
behave~=1.2.6
|
||||
huggingface_hub~=0.20.3
|
||||
numpy~=1.24.4
|
||||
openai~=0.25.0
|
||||
numpy~=1.26.4
|
||||
openai~=1.30.3
|
||||
prometheus-client~=0.20.0
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
import asyncio
|
||||
import asyncio.threads
|
||||
import requests
|
||||
import numpy as np
|
||||
|
||||
|
||||
n = 8
|
||||
|
||||
result = []
|
||||
|
||||
async def requests_post_async(*args, **kwargs):
|
||||
return await asyncio.to_thread(requests.post, *args, **kwargs)
|
||||
return await asyncio.threads.to_thread(requests.post, *args, **kwargs)
|
||||
|
||||
async def main():
|
||||
model_url = "http://127.0.0.1:6900"
|
||||
|
|
|
@ -66,7 +66,7 @@ class Tensor:
|
|||
if len(self.ne) == 0:
|
||||
self.nbytes = 0
|
||||
else:
|
||||
self.nbytes = int(np.product(self.ne)) * 4
|
||||
self.nbytes = int(np.prod(self.ne)) * 4
|
||||
else:
|
||||
raise ValueError(f"Unhandled data type '{self.dtype}'")
|
||||
|
||||
|
|
|
@ -99,6 +99,8 @@ async def main():
|
|||
|
||||
tasks = []
|
||||
|
||||
base_dict = {"FLOAT_TYPE": "float"}
|
||||
|
||||
for fp16 in (False, True):
|
||||
# MUL_MAT
|
||||
matmul_shaders(tasks, fp16, False)
|
||||
|
@ -106,8 +108,6 @@ async def main():
|
|||
matmul_shaders(tasks, fp16, True)
|
||||
|
||||
for tname in type_names:
|
||||
base_dict = {"FLOAT_TYPE": "float"}
|
||||
|
||||
# mul mat vec
|
||||
data_a_key = f"DATA_A_{tname.upper()}"
|
||||
shader = f"mul_mat_vec_{tname}.comp" if tname.endswith("_k") else "mul_mat_vec.comp"
|
||||
|
|
|
@ -67,7 +67,7 @@ class ReaderTensor(NamedTuple):
|
|||
|
||||
class GGUFReader:
|
||||
# I - same as host, S - swapped
|
||||
byte_order: Literal['I'] | Literal['S'] = 'I'
|
||||
byte_order: Literal['I', 'S'] = 'I'
|
||||
alignment: int = GGUF_DEFAULT_ALIGNMENT
|
||||
data_offset: int
|
||||
|
||||
|
@ -86,7 +86,7 @@ class GGUFReader:
|
|||
GGUFValueType.BOOL: np.bool_,
|
||||
}
|
||||
|
||||
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r'] | Literal['r+'] | Literal['c'] = 'r'):
|
||||
def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
|
||||
self.data = np.memmap(path, mode = mode)
|
||||
offs = 0
|
||||
|
||||
|
@ -140,7 +140,7 @@ class GGUFReader:
|
|||
return self.tensors[idx]
|
||||
|
||||
def _get(
|
||||
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I'] | Literal['S'] | Literal['<'] = None,
|
||||
self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
|
||||
) -> npt.NDArray[Any]:
|
||||
count = int(count)
|
||||
itemsize = int(np.empty([], dtype = dtype).itemsize)
|
||||
|
|
|
@ -16,16 +16,16 @@ logger = logging.getLogger(__name__)
|
|||
class LazyMeta(ABCMeta):
|
||||
|
||||
def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
|
||||
def __getattr__(self, __name: str) -> Any:
|
||||
meta_attr = getattr(self._meta, __name)
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
meta_attr = getattr(self._meta, name)
|
||||
if callable(meta_attr):
|
||||
return type(self)._wrap_fn(
|
||||
(lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)),
|
||||
(lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
|
||||
use_self=self,
|
||||
)
|
||||
elif isinstance(meta_attr, self._tensor_type):
|
||||
# e.g. self.T with torch.Tensor should still be wrapped
|
||||
return type(self)._wrap_fn(lambda s: getattr(s, __name))(self)
|
||||
return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
|
||||
else:
|
||||
# no need to wrap non-tensor properties,
|
||||
# and they likely don't depend on the actual contents of the tensor
|
||||
|
@ -141,19 +141,21 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
|||
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
|
||||
|
||||
if isinstance(res, cls._tensor_type):
|
||||
def collect_replace(t: LazyBase):
|
||||
if collect_replace.shared_lazy is None:
|
||||
collect_replace.shared_lazy = t._lazy
|
||||
else:
|
||||
collect_replace.shared_lazy.extend(t._lazy)
|
||||
t._lazy = collect_replace.shared_lazy
|
||||
|
||||
class CollectSharedLazy:
|
||||
# emulating a static variable
|
||||
collect_replace.shared_lazy = None
|
||||
shared_lazy: None | deque[LazyBase] = None
|
||||
|
||||
LazyBase._recurse_apply(args, collect_replace)
|
||||
@staticmethod
|
||||
def collect_replace(t: LazyBase):
|
||||
if CollectSharedLazy.shared_lazy is None:
|
||||
CollectSharedLazy.shared_lazy = t._lazy
|
||||
else:
|
||||
CollectSharedLazy.shared_lazy.extend(t._lazy)
|
||||
t._lazy = CollectSharedLazy.shared_lazy
|
||||
|
||||
shared_lazy = collect_replace.shared_lazy
|
||||
LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
|
||||
|
||||
shared_lazy = CollectSharedLazy.shared_lazy
|
||||
|
||||
return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
|
||||
else:
|
||||
|
@ -184,6 +186,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
|
|||
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
|
||||
lt._data = lt._func(lt._args)
|
||||
# sanity check
|
||||
assert lt._data is not None
|
||||
assert lt._data.dtype == lt._meta.dtype
|
||||
assert lt._data.shape == lt._meta.shape
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
# pyright: reportUnusedImport=false
|
||||
|
||||
from .gguf_convert_endian import main as gguf_convert_endian_entrypoint
|
||||
from .gguf_dump import main as gguf_dump_entrypoint
|
||||
from .gguf_set_metadata import main as gguf_set_metadata_entrypoint
|
||||
|
|
|
@ -63,9 +63,9 @@ def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar) -> None:
|
|||
bar.update(sum_weights_in_tensor)
|
||||
|
||||
sha1_layer = hashlib.sha1()
|
||||
sha1_layer.update(tensor.data)
|
||||
sha1.update(tensor.data)
|
||||
uuidv5_sha1.update(tensor.data)
|
||||
sha1_layer.update(tensor.data.data)
|
||||
sha1.update(tensor.data.data)
|
||||
uuidv5_sha1.update(tensor.data.data)
|
||||
print("sha1 {0} {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100
|
||||
|
||||
# Flush Hash Progress Bar
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import os
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import gguf # noqa: F401
|
||||
import gguf # noqa: F401 # pyright: ignore[reportUnusedImport]
|
||||
|
||||
# TODO: add tests
|
||||
|
||||
|
|
|
@ -1,3 +1,21 @@
|
|||
{
|
||||
"extraPaths": ["gguf-py"],
|
||||
}
|
||||
"pythonVersion": "3.9",
|
||||
"pythonPlatform": "All",
|
||||
"reportUnusedImport": "warning",
|
||||
"reportDuplicateImport": "error",
|
||||
"reportDeprecated": "warning",
|
||||
"reportUnnecessaryTypeIgnoreComment": "warning",
|
||||
"executionEnvironments": [
|
||||
{
|
||||
// TODO: make this version override work correctly
|
||||
"root": "gguf-py",
|
||||
"pythonVersion": "3.8",
|
||||
},
|
||||
{
|
||||
// uses match expressions in steps.py
|
||||
"root": "examples/server/tests",
|
||||
"pythonVersion": "3.10",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
|
12
requirements/requirements-all.txt
Normal file
12
requirements/requirements-all.txt
Normal file
|
@ -0,0 +1,12 @@
|
|||
-r ../examples/llava/requirements.txt
|
||||
-r ../examples/server/bench/requirements.txt
|
||||
-r ../examples/server/tests/requirements.txt
|
||||
|
||||
-r ./requirements-compare-llama-bench.txt
|
||||
-r ./requirements-pydantic.txt
|
||||
-r ./requirements-test-tokenizer-random.txt
|
||||
|
||||
-r ./requirements-convert_hf_to_gguf.txt
|
||||
-r ./requirements-convert_hf_to_gguf_update.txt
|
||||
-r ./requirements-convert_legacy_llama.txt
|
||||
-r ./requirements-convert_llama_ggml_to_gguf.txt
|
2
requirements/requirements-compare-llama-bench.txt
Normal file
2
requirements/requirements-compare-llama-bench.txt
Normal file
|
@ -0,0 +1,2 @@
|
|||
tabulate~=0.9.0
|
||||
GitPython~=3.1.43
|
2
requirements/requirements-pydantic.txt
Normal file
2
requirements/requirements-pydantic.txt
Normal file
|
@ -0,0 +1,2 @@
|
|||
docstring_parser~=0.15
|
||||
pydantic~=2.6.3
|
1
requirements/requirements-test-tokenizer-random.txt
Normal file
1
requirements/requirements-test-tokenizer-random.txt
Normal file
|
@ -0,0 +1 @@
|
|||
cffi~=1.16.0
|
|
@ -108,6 +108,11 @@ check_convert_script() {
|
|||
fatal "$py missing requirements. Expected: $reqs"
|
||||
fi
|
||||
|
||||
# Check that all sub-requirements are added to top-level requirements.txt
|
||||
if ! grep -qF "$reqs" requirements.txt; then
|
||||
fatal "$reqs needs to be added to requirements.txt"
|
||||
fi
|
||||
|
||||
local venv="$workdir/$pyname-venv"
|
||||
python3 -m venv "$venv"
|
||||
|
||||
|
@ -134,12 +139,7 @@ EOF
|
|||
|
||||
readonly ignore_eq_eq='check_requirements: ignore "=="'
|
||||
|
||||
for req in "$reqs_dir"/*; do
|
||||
# Check that all sub-requirements are added to top-level requirements.txt
|
||||
if ! grep -qF "$req" requirements.txt; then
|
||||
fatal "$req needs to be added to requirements.txt"
|
||||
fi
|
||||
|
||||
for req in */**/requirements*.txt; do
|
||||
# Make sure exact release versions aren't being pinned in the requirements
|
||||
# Filters out the ignore string
|
||||
if grep -vF "$ignore_eq_eq" "$req" | grep -q '=='; then
|
||||
|
|
|
@ -123,13 +123,13 @@ builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall()
|
|||
|
||||
try:
|
||||
repo = git.Repo(".", search_parent_directories=True)
|
||||
except git.exc.InvalidGitRepositoryError:
|
||||
except git.InvalidGitRepositoryError:
|
||||
repo = None
|
||||
|
||||
|
||||
def find_parent_in_data(commit):
|
||||
def find_parent_in_data(commit: git.Commit):
|
||||
"""Helper function to find the most recent parent measured in number of commits for which there is data."""
|
||||
heap = [(0, commit)]
|
||||
heap: list[tuple[int, git.Commit]] = [(0, commit)]
|
||||
seen_hexsha8 = set()
|
||||
while heap:
|
||||
depth, current_commit = heapq.heappop(heap)
|
||||
|
@ -144,7 +144,7 @@ def find_parent_in_data(commit):
|
|||
return None
|
||||
|
||||
|
||||
def get_all_parent_hexsha8s(commit):
|
||||
def get_all_parent_hexsha8s(commit: git.Commit):
|
||||
"""Helper function to recursively get hexsha8 values for all parents of a commit."""
|
||||
unvisited = [commit]
|
||||
visited = []
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import array
|
||||
import unicodedata
|
||||
import requests
|
||||
|
@ -133,7 +135,7 @@ table_nfd.sort()
|
|||
|
||||
|
||||
# group ranges with same flags
|
||||
ranges_flags = [(0, codepoint_flags[0])] # start, flags
|
||||
ranges_flags: list[tuple[int, int]] = [(0, codepoint_flags[0])] # start, flags
|
||||
for codepoint, flags in enumerate(codepoint_flags):
|
||||
if flags != ranges_flags[-1][1]:
|
||||
ranges_flags.append((codepoint, flags))
|
||||
|
@ -141,11 +143,11 @@ ranges_flags.append((MAX_CODEPOINTS, 0x0000))
|
|||
|
||||
|
||||
# group ranges with same nfd
|
||||
ranges_nfd = [(0, 0, 0)] # start, last, nfd
|
||||
ranges_nfd: list[tuple[int, int, int]] = [(0, 0, 0)] # start, last, nfd
|
||||
for codepoint, norm in table_nfd:
|
||||
start = ranges_nfd[-1][0]
|
||||
if ranges_nfd[-1] != (start, codepoint - 1, norm):
|
||||
ranges_nfd.append(None)
|
||||
ranges_nfd.append(None) # type: ignore[arg-type] # dummy, will be replaced below
|
||||
start = codepoint
|
||||
ranges_nfd[-1] = (start, codepoint, norm)
|
||||
|
||||
|
@ -179,13 +181,13 @@ for codepoint in table_whitespace:
|
|||
out("};\n")
|
||||
|
||||
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
|
||||
for tuple in table_lowercase:
|
||||
out("{0x%06X, 0x%06X}," % tuple)
|
||||
for tuple_lw in table_lowercase:
|
||||
out("{0x%06X, 0x%06X}," % tuple_lw)
|
||||
out("};\n")
|
||||
|
||||
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase = {")
|
||||
for tuple in table_uppercase:
|
||||
out("{0x%06X, 0x%06X}," % tuple)
|
||||
for tuple_up in table_uppercase:
|
||||
out("{0x%06X, 0x%06X}," % tuple_up)
|
||||
out("};\n")
|
||||
|
||||
out("const std::vector<range_nfd> unicode_ranges_nfd = { // start, last, nfd")
|
||||
|
|
|
@ -6,6 +6,8 @@
|
|||
# python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import logging
|
||||
import argparse
|
||||
|
@ -13,7 +15,9 @@ import subprocess
|
|||
import random
|
||||
import unicodedata
|
||||
|
||||
from typing import Iterator
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterator, cast
|
||||
from typing_extensions import Buffer
|
||||
|
||||
import cffi
|
||||
from transformers import AutoTokenizer
|
||||
|
@ -28,15 +32,15 @@ class LibLlama:
|
|||
DEFAULT_PATH_INCLUDES = ["./ggml/include/", "./include/"]
|
||||
DEFAULT_PATH_LIBLLAMA = "./build/src/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON
|
||||
|
||||
def __init__(self, path_llama_h: str = None, path_includes: list[str] = [], path_libllama: str = None):
|
||||
def __init__(self, path_llama_h: str | None = None, path_includes: list[str] = [], path_libllama: str | None = None):
|
||||
path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
|
||||
path_includes = path_includes or self.DEFAULT_PATH_INCLUDES
|
||||
path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
|
||||
(self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_includes, path_libllama)
|
||||
self.lib.llama_backend_init()
|
||||
|
||||
def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str):
|
||||
cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="]
|
||||
def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str) -> tuple[cffi.FFI, Any]:
|
||||
cmd = ["gcc", "-O0", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="]
|
||||
cmd += ["-I" + path for path in path_includes] + [path_llama_h]
|
||||
res = subprocess.run(cmd, stdout=subprocess.PIPE)
|
||||
assert (res.returncode == 0)
|
||||
|
@ -68,7 +72,7 @@ class LibLlama:
|
|||
class LibLlamaModel:
|
||||
|
||||
def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
|
||||
self.lib = libllama.lib
|
||||
self.lib: Any = libllama.lib
|
||||
self.ffi = libllama.ffi
|
||||
if isinstance(mparams, dict):
|
||||
mparams = libllama.model_default_params(**mparams)
|
||||
|
@ -94,11 +98,11 @@ class LibLlamaModel:
|
|||
self.lib = None
|
||||
|
||||
def tokenize(self, text: str, add_special: bool = False, parse_special: bool = False) -> list[int]:
|
||||
text = text.encode("utf-8")
|
||||
num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special)
|
||||
encoded_text: bytes = text.encode("utf-8")
|
||||
num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special)
|
||||
while num < 0 and len(self.token_ids) < (16 << 20):
|
||||
self.token_ids = self.ffi.new("llama_token[]", -2 * num)
|
||||
num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, len(self.token_ids), add_special, parse_special)
|
||||
num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special)
|
||||
return list(self.token_ids[0:num])
|
||||
|
||||
def detokenize(self, ids: list[int], remove_special: bool = False, unparse_special: bool = False) -> str:
|
||||
|
@ -110,7 +114,7 @@ class LibLlamaModel:
|
|||
while num < 0 and len(self.text_buff) < (16 << 20):
|
||||
self.text_buff = self.ffi.new("uint8_t[]", -2 * num)
|
||||
num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special)
|
||||
return str(self.ffi.buffer(self.text_buff, num), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
|
||||
return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD'
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
|
@ -152,7 +156,7 @@ class TokenizerGroundtruth (Tokenizer):
|
|||
|
||||
class TokenizerLlamaCpp (Tokenizer):
|
||||
|
||||
libllama: LibLlama = None
|
||||
libllama: LibLlama | None = None
|
||||
|
||||
def __init__(self, vocab_file: str):
|
||||
if not self.libllama:
|
||||
|
@ -404,7 +408,7 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100
|
|||
|
||||
def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]):
|
||||
|
||||
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
||||
def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str):
|
||||
for i, (a, b) in enumerate(zip(ids1, ids2)):
|
||||
if a != b:
|
||||
return i
|
||||
|
@ -433,7 +437,7 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
|
|||
decode_errors = 0
|
||||
MAX_ERRORS = 10
|
||||
|
||||
logger.info("%s: %s" % (generator.__name__, "ini"))
|
||||
logger.info("%s: %s" % (generator.__qualname__, "ini"))
|
||||
for text in generator:
|
||||
# print(repr(text), text.encode())
|
||||
# print(repr(text), hex(ord(text[0])), text.encode())
|
||||
|
@ -472,13 +476,13 @@ def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLl
|
|||
break
|
||||
|
||||
t_total = time.perf_counter() - t_start
|
||||
logger.info(f"{generator.__name__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
|
||||
logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}")
|
||||
|
||||
|
||||
def main(argv: list[str] = None):
|
||||
def main(argv: list[str] | None = None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
|
||||
parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
|
||||
parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file")
|
||||
parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file")
|
||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
|
@ -520,7 +524,7 @@ if __name__ == "__main__":
|
|||
format = "%(levelname)s %(message)s",
|
||||
)
|
||||
|
||||
path_tokenizers = "./models/tokenizers/"
|
||||
path_tokenizers = Path("./models/tokenizers/")
|
||||
path_vocab_format = "./models/ggml-vocab-%s.gguf"
|
||||
|
||||
tokenizers = [
|
||||
|
@ -556,6 +560,6 @@ if __name__ == "__main__":
|
|||
for tokenizer in tokenizers:
|
||||
logger.info("-" * 50)
|
||||
logger.info(f"TOKENIZER: '{tokenizer}'")
|
||||
vocab_file = path_vocab_format % tokenizer
|
||||
dir_tokenizer = path_tokenizers + "/" + tokenizer
|
||||
main([vocab_file, dir_tokenizer, "--verbose"])
|
||||
vocab_file = Path(path_vocab_format % tokenizer)
|
||||
dir_tokenizer = path_tokenizers / tokenizer
|
||||
main([str(vocab_file), str(dir_tokenizer), "--verbose"])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue