From 084dd216cd18cd33023853ad83107efebdb5088e Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Sun, 27 Aug 2023 10:58:13 -0600
Subject: [PATCH] Cleanups for gguf-py

---
 gguf-py/gguf/gguf.py   | 106 ++++++++++++++++++++---------------------
 gguf-py/gguf/py.typed  |   0
 gguf-py/pyproject.toml |   1 +
 3 files changed, 54 insertions(+), 53 deletions(-)
 create mode 100644 gguf-py/gguf/py.typed

diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py
index 838a2c0f8..8f119c534 100644
--- a/gguf-py/gguf/gguf.py
+++ b/gguf-py/gguf/gguf.py
@@ -6,7 +6,7 @@ import tempfile
 import numpy as np
 
 from enum import IntEnum, auto
-from typing import Any, IO, List, Optional
+from typing import Any, BinaryIO, IO, Dict, List, Optional, Tuple
 
 #
 # constants
@@ -71,35 +71,35 @@ KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
 
 
 class MODEL_ARCH(IntEnum):
-    LLAMA   = auto()
-    FALCON  = auto()
-    GPT2    = auto()
-    GPTJ    = auto()
-    GPTNEOX = auto()
-    MPT     = auto()
+    LLAMA  : int = auto()
+    FALCON : int = auto()
+    GPT2   : int = auto()
+    GPTJ   : int = auto()
+    GPTNEOX: int = auto()
+    MPT    : int = auto()
 
 
 class MODEL_TENSOR(IntEnum):
-    TOKEN_EMBD    = auto()
-    POS_EMBD      = auto()
-    OUTPUT        = auto()
-    OUTPUT_NORM   = auto()
-    ROPE_FREQS    = auto()
-    ATTN_Q        = auto()
-    ATTN_K        = auto()
-    ATTN_V        = auto()
-    ATTN_QKV      = auto()
-    ATTN_OUT      = auto()
-    ATTN_NORM     = auto()
-    ATTN_NORM_2   = auto()
-    ATTN_ROT_EMBD = auto()
-    FFN_GATE      = auto()
-    FFN_DOWN      = auto()
-    FFN_UP        = auto()
-    FFN_NORM      = auto()
+    TOKEN_EMBD   : int = auto()
+    POS_EMBD     : int = auto()
+    OUTPUT       : int = auto()
+    OUTPUT_NORM  : int = auto()
+    ROPE_FREQS   : int = auto()
+    ATTN_Q       : int = auto()
+    ATTN_K       : int = auto()
+    ATTN_V       : int = auto()
+    ATTN_QKV     : int = auto()
+    ATTN_OUT     : int = auto()
+    ATTN_NORM    : int = auto()
+    ATTN_NORM_2  : int = auto()
+    ATTN_ROT_EMBD: int = auto()
+    FFN_GATE     : int = auto()
+    FFN_DOWN     : int = auto()
+    FFN_UP       : int = auto()
+    FFN_NORM     : int = auto()
 
 
-MODEL_ARCH_NAMES = {
+MODEL_ARCH_NAMES: Dict[MODEL_ARCH, str] = {
     MODEL_ARCH.LLAMA:   "llama",
     MODEL_ARCH.FALCON:  "falcon",
     MODEL_ARCH.GPT2:    "gpt2",
@@ -108,7 +108,7 @@ MODEL_ARCH_NAMES = {
     MODEL_ARCH.MPT:     "mpt",
 }
 
-MODEL_TENSOR_NAMES = {
+MODEL_TENSOR_NAMES: Dict[MODEL_ARCH, Dict[MODEL_TENSOR, str]] = {
     MODEL_ARCH.LLAMA: {
         MODEL_TENSOR.TOKEN_EMBD:    "token_embd",
         MODEL_TENSOR.OUTPUT_NORM:   "output_norm",
@@ -154,7 +154,7 @@ MODEL_TENSOR_NAMES = {
 }
 
 # tensors that will not be serialized
-MODEL_TENSOR_SKIP = {
+MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
     MODEL_ARCH.LLAMA: [
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
@@ -388,15 +388,21 @@ class GGUFValueType(IntEnum):
 
 
 class GGUFWriter:
+    fout: BinaryIO
+    arch: str
+    offset_tensor = 0
+    data_alignment = GGUF_DEFAULT_ALIGNMENT
+    kv_data = b""
+    kv_data_count = 0
+    ti_data = b""
+    ti_data_count = 0
+    use_temp_file: bool
+    temp_file: Optional[tempfile.SpooledTemporaryFile[bytes]] = None
+    tensors: List[Tuple[np.ndarray, int]]
+
     def __init__(self, path: str, arch: str, use_temp_file = True):
         self.fout = open(path, "wb")
         self.arch = arch
-        self.offset_tensor = 0
-        self.data_alignment = GGUF_DEFAULT_ALIGNMENT
-        self.kv_data = b""
-        self.kv_data_count = 0
-        self.ti_data = b""
-        self.ti_data_count = 0
         self.add_architecture()
         self.use_temp_file = use_temp_file
         self.tensors = []
@@ -477,7 +483,7 @@ class GGUFWriter:
         self.add_key(key)
         self.add_val(val, GGUFValueType.ARRAY)
 
-    def add_val(self: str, val: Any, vtype: GGUFValueType = None, add_vtype: bool = True):
+    def add_val(self, val: Any, vtype: Optional[GGUFValueType] = None, add_vtype: bool = True):
         if vtype is None:
             vtype = GGUFValueType.get_type(val)
 
@@ -545,15 +551,16 @@ class GGUFWriter:
         self.ti_data_count += 1
 
     def add_tensor(self, name: str, tensor: np.ndarray, raw_shape: Optional[np.ndarray] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
-        if self.use_temp_file and not hasattr(self, "temp_file"):
-            self.temp_file = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
-            self.temp_file.seek(0)
+        if self.use_temp_file and self.temp_file is None:
+            fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
+            fp.seek(0)
+            self.temp_file = fp
 
         self.add_tensor_info(name, raw_shape if raw_shape is not None else tensor.shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
 
         pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
 
-        if not self.use_temp_file:
+        if self.temp_file is None:
             self.tensors.append((tensor, pad))
             return
 
@@ -562,23 +569,20 @@
         if pad != 0:
             self.temp_file.write(bytes([0] * pad))
 
+    def write_padding(self, fp: BinaryIO, n: int, align: Optional[int] = None):
+        pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
+        if pad != 0:
+            fp.write(bytes([0] * pad))
+
     def write_tensor_data(self, tensor: np.ndarray):
-        pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell()
-        if pad != 0:
-            self.fout.write(bytes([0] * pad))
-
+        self.write_padding(self.fout, self.fout.tell())
         tensor.tofile(self.fout)
-
-        pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
-        if pad != 0:
-            self.fout.write(bytes([0] * pad))
+        self.write_padding(self.fout, tensor.nbytes)
 
     def write_tensors_to_file(self):
         self.write_ti_data_to_file()
 
-        pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell()
-        if pad != 0:
-            self.fout.write(bytes([0] * pad))
+        self.write_padding(self.fout, self.fout.tell())
 
         if not self.use_temp_file:
             for (currtensor, currpad) in self.tensors:
@@ -654,10 +658,6 @@
         self.add_bool(
             KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
 
-    def add_tensor_data_layout(self, layout: str):
-        self.add_string(
-            KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
-
     def add_head_count(self, count: int):
         self.add_uint32(
             KEY_ATTENTION_HEAD_COUNT.format(arch=self.arch), count)
diff --git a/gguf-py/gguf/py.typed b/gguf-py/gguf/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml
index cc70e28b7..c66b069f9 100644
--- a/gguf-py/pyproject.toml
+++ b/gguf-py/pyproject.toml
@@ -5,6 +5,7 @@ description = "Write ML models in GGUF for GGML"
 authors = ["GGML "]
 packages = [
     {include = "gguf"},
+    {include = "gguf/py.typed"},
 ]
 readme = "README.md"
 homepage = "https://ggml.ai"
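
Reviewer sketch, not part of the patch: the three duplicated pad-and-write sequences deleted above all reduce to the new write_padding helper. Below is a minimal standalone illustration of that alignment logic. It assumes GGUFWriter.ggml_pad(x, n) rounds x up to the next multiple of n and that GGUF_DEFAULT_ALIGNMENT is 32; the ggml_pad defined here is a local stand-in for illustration, not the module's code.

import io

GGUF_DEFAULT_ALIGNMENT = 32  # assumed to match the constant in gguf.py

def ggml_pad(x: int, n: int) -> int:
    # Stand-in for GGUFWriter.ggml_pad: round x up to a multiple of n.
    return ((x + n - 1) // n) * n

def write_padding(fp, n: int, align: int = GGUF_DEFAULT_ALIGNMENT) -> None:
    # Mirrors the shape of the new GGUFWriter.write_padding: zero-fill
    # from n bytes up to the next alignment boundary.
    pad = ggml_pad(n, align) - n
    if pad != 0:
        fp.write(bytes([0] * pad))

buf = io.BytesIO()
buf.write(b"\x01" * 10)          # e.g. 10 bytes of tensor data
write_padding(buf, buf.tell())   # writes 22 zero bytes: 10 -> 32
assert buf.tell() == 32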