convert : fix python 3.8 support, modernize type annotations (#2916)

* convert : fix python 3.8 support

* convert : sort imports

* convert : fix required parameters in convert-llama-ggmlv3-to-gguf

* convert : fix mypy errors in convert-llama-ggmlv3-to-gguf

* convert : use PEP 585 generics and PEP 604 unions

Now that we have `from __future__ import annotations`, we can use this
modern syntax in Python 3.7 instead of restricting support to Python 3.9
or 3.10 respectively.

* gguf.py : a tuple is already a tuple

* add mypy.ini

* convert : add necessary `type: ignore` comments

* gguf-py: bump version
This commit is contained in:
Cebtenzzre 2023-08-31 01:02:23 -04:00 committed by GitHub
parent 8afe228000
commit 92d0b751a7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 193 additions and 168 deletions

View file

@ -1,16 +1,18 @@
#!/usr/bin/env python3
import shutil
import sys
import struct
import tempfile
import numpy as np
from __future__ import annotations
import json
import os
from pathlib import Path
import shutil
import struct
import sys
import tempfile
from enum import IntEnum, auto
from io import BufferedWriter
from typing import Any, BinaryIO, Callable, IO, Dict, List, Optional, Sequence, Tuple, Union
from pathlib import Path
from typing import IO, Any, BinaryIO, Callable, Sequence
import numpy as np
#
# constants
@ -103,7 +105,7 @@ class MODEL_TENSOR(IntEnum):
FFN_NORM : int = auto()
MODEL_ARCH_NAMES: Dict[MODEL_ARCH, str] = {
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.GPT2: "gpt2",
@ -112,7 +114,7 @@ MODEL_ARCH_NAMES: Dict[MODEL_ARCH, str] = {
MODEL_ARCH.MPT: "mpt",
}
MODEL_TENSOR_NAMES: Dict[MODEL_ARCH, Dict[MODEL_TENSOR, str]] = {
MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
MODEL_ARCH.LLAMA: {
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
@ -158,7 +160,7 @@ MODEL_TENSOR_NAMES: Dict[MODEL_ARCH, Dict[MODEL_TENSOR, str]] = {
}
# tensors that will not be serialized
MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_ARCH.LLAMA: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
@ -167,7 +169,7 @@ MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
class TensorNameMap:
mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
@ -203,7 +205,7 @@ class TensorNameMap:
),
}
block_mappings_cfg: Dict[MODEL_TENSOR, Tuple[str, ...]] = {
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
@ -298,9 +300,9 @@ class TensorNameMap:
),
}
mapping: Dict[str, Tuple[MODEL_TENSOR, str]]
mapping: dict[str, tuple[MODEL_TENSOR, str]]
tensor_names: Dict[MODEL_TENSOR, str]
tensor_names: dict[MODEL_TENSOR, str]
def __init__(self, arch: MODEL_ARCH, n_blocks: int):
mapping = self.mapping = {}
@ -321,7 +323,7 @@ class TensorNameMap:
key = key.format(bid = bid)
mapping[key] = (tensor, tensor_name)
def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[Tuple[MODEL_TENSOR, str]]:
def get_type_and_name(self, key: str, try_suffixes: Sequence[str]) -> tuple[MODEL_TENSOR, str] | None:
result = self.mapping.get(key)
if result is not None:
return result
@ -332,13 +334,13 @@ class TensorNameMap:
return (result[0], result[1] + suffix)
return None
def get_name(self, key: str, try_suffixes: Sequence[str]) -> Optional[str]:
def get_name(self, key: str, try_suffixes: Sequence[str]) -> str | None:
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
if result is None:
return None
return result[1]
def get_type(self, key: str, try_suffixes: Sequence[str]) -> Optional[MODEL_TENSOR]:
def get_type(self, key: str, try_suffixes: Sequence[str]) -> MODEL_TENSOR | None:
result = self.get_type_and_name(key, try_suffixes = try_suffixes)
if result is None:
return None
@ -432,10 +434,10 @@ class GGUFWriter:
ti_data = b""
ti_data_count = 0
use_temp_file: bool
temp_file: Optional[tempfile.SpooledTemporaryFile[bytes]] = None
tensors: List[Tuple[np.ndarray[Any, Any], int]]
temp_file: tempfile.SpooledTemporaryFile[bytes] | None = None
tensors: list[tuple[np.ndarray[Any, Any], int]]
def __init__(self, path: Union[os.PathLike[str], str], arch: str, use_temp_file = True):
def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file = True):
self.fout = open(path, "wb")
self.arch = arch
self.add_architecture()
@ -531,7 +533,7 @@ class GGUFWriter:
GGUFValueType.FLOAT64: "<d",
GGUFValueType.BOOL: "?" ,
}
def add_val(self, val: Any, vtype: Optional[GGUFValueType] = None, add_vtype: bool = True):
def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True):
if vtype is None:
vtype = GGUFValueType.get_type(val)
@ -561,7 +563,7 @@ class GGUFWriter:
def ggml_pad(x: int, n: int) -> int:
return ((x + n - 1) // n) * n
def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: Union[np.dtype[np.float16], np.dtype[np.float32]], tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None):
def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32], tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None):
assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"
encoded_name = name.encode("utf8")
@ -580,7 +582,7 @@ class GGUFWriter:
self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
self.ti_data_count += 1
def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Optional[Sequence[int]] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None):
if self.use_temp_file and self.temp_file is None:
fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
fp.seek(0)
@ -600,7 +602,7 @@ class GGUFWriter:
if pad != 0:
self.temp_file.write(bytes([0] * pad))
def write_padding(self, fp: BinaryIO, n: int, align: Optional[int] = None):
def write_padding(self, fp: BinaryIO, n: int, align: int | None = None):
pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
if pad != 0:
fp.write(bytes([0] * pad))
@ -726,13 +728,13 @@ class GGUFWriter:
def add_tokenizer_model(self, model: str):
self.add_string(KEY_TOKENIZER_MODEL, model)
def add_token_list(self, tokens: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]):
self.add_array(KEY_TOKENIZER_LIST, tokens)
def add_token_merges(self, merges: Union[Sequence[str], Sequence[bytes], Sequence[bytearray]]):
def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]):
self.add_array(KEY_TOKENIZER_MERGES, merges)
def add_token_types(self, types: Union[Sequence[TokenType], Sequence[int]]):
def add_token_types(self, types: Sequence[TokenType] | Sequence[int]):
self.add_array(KEY_TOKENIZER_TOKEN_TYPE, types)
def add_token_scores(self, scores: Sequence[float]):
@ -756,11 +758,11 @@ class GGUFWriter:
class SpecialVocab:
load_merges: bool = False
merges: List[str] = []
special_token_types: Tuple[str, ...] = tuple(('bos', 'eos', 'unk', 'sep', 'pad'))
special_token_ids: Dict[str, int] = {}
merges: list[str] = []
special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad')
special_token_ids: dict[str, int] = {}
def __init__(self, path: Path, load_merges: bool = False, special_token_types: Optional[Tuple[str, ...]] = None):
def __init__(self, path: Path, load_merges: bool = False, special_token_types: tuple[str, ...] | None = None):
self.special_token_ids = {}
self.load_merges = load_merges
if special_token_types is not None:
@ -821,7 +823,7 @@ class SpecialVocab:
print(f'gguf: Adding {len(self.merges)} merge(s).')
gw.add_token_merges(self.merges)
for typ, tokid in self.special_token_ids.items():
handler: Optional[Callable[[int], None]] = getattr(gw, f'add_{typ}_token_id', None)
handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
if handler is None:
print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping')
continue

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "gguf"
version = "0.2.1"
version = "0.3.1"
description = "Write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [