Cleanups for gguf-py

This commit is contained in:
KerfuffleV2 2023-08-27 10:58:13 -06:00
parent 1793f25cfa
commit 084dd216cd
3 changed files with 54 additions and 53 deletions

View file

@ -6,7 +6,7 @@ import tempfile
import numpy as np import numpy as np
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Any, IO, List, Optional from typing import Any, BinaryIO, IO, Dict, List, Optional, Tuple
# #
# constants # constants
@ -71,35 +71,35 @@ KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
class MODEL_ARCH(IntEnum): class MODEL_ARCH(IntEnum):
LLAMA = auto() LLAMA : int = auto()
FALCON = auto() FALCON : int = auto()
GPT2 = auto() GPT2 : int = auto()
GPTJ = auto() GPTJ : int = auto()
GPTNEOX = auto() GPTNEOX: int = auto()
MPT = auto() MPT : int = auto()
class MODEL_TENSOR(IntEnum): class MODEL_TENSOR(IntEnum):
TOKEN_EMBD = auto() TOKEN_EMBD : int = auto()
POS_EMBD = auto() POS_EMBD : int = auto()
OUTPUT = auto() OUTPUT : int = auto()
OUTPUT_NORM = auto() OUTPUT_NORM : int = auto()
ROPE_FREQS = auto() ROPE_FREQS : int = auto()
ATTN_Q = auto() ATTN_Q : int = auto()
ATTN_K = auto() ATTN_K : int = auto()
ATTN_V = auto() ATTN_V : int = auto()
ATTN_QKV = auto() ATTN_QKV : int = auto()
ATTN_OUT = auto() ATTN_OUT : int = auto()
ATTN_NORM = auto() ATTN_NORM : int = auto()
ATTN_NORM_2 = auto() ATTN_NORM_2 : int = auto()
ATTN_ROT_EMBD = auto() ATTN_ROT_EMBD: int = auto()
FFN_GATE = auto() FFN_GATE : int = auto()
FFN_DOWN = auto() FFN_DOWN : int = auto()
FFN_UP = auto() FFN_UP : int = auto()
FFN_NORM = auto() FFN_NORM : int = auto()
MODEL_ARCH_NAMES = { MODEL_ARCH_NAMES: Dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama", MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon", MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.GPT2: "gpt2", MODEL_ARCH.GPT2: "gpt2",
@ -108,7 +108,7 @@ MODEL_ARCH_NAMES = {
MODEL_ARCH.MPT: "mpt", MODEL_ARCH.MPT: "mpt",
} }
MODEL_TENSOR_NAMES = { MODEL_TENSOR_NAMES: Dict[MODEL_ARCH, Dict[MODEL_TENSOR, str]] = {
MODEL_ARCH.LLAMA: { MODEL_ARCH.LLAMA: {
MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.OUTPUT_NORM: "output_norm", MODEL_TENSOR.OUTPUT_NORM: "output_norm",
@ -154,7 +154,7 @@ MODEL_TENSOR_NAMES = {
} }
# tensors that will not be serialized # tensors that will not be serialized
MODEL_TENSOR_SKIP = { MODEL_TENSOR_SKIP: Dict[MODEL_ARCH, List[MODEL_TENSOR]] = {
MODEL_ARCH.LLAMA: [ MODEL_ARCH.LLAMA: [
MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD, MODEL_TENSOR.ATTN_ROT_EMBD,
@ -388,15 +388,21 @@ class GGUFValueType(IntEnum):
class GGUFWriter: class GGUFWriter:
fout: BinaryIO
arch: str
offset_tensor = 0
data_alignment = GGUF_DEFAULT_ALIGNMENT
kv_data = b""
kv_data_count = 0
ti_data = b""
ti_data_count = 0
use_temp_file: bool
temp_file: Optional[tempfile.SpooledTemporaryFile[bytes]] = None
tensors: List[Tuple[np.ndarray, int]]
def __init__(self, path: str, arch: str, use_temp_file = True): def __init__(self, path: str, arch: str, use_temp_file = True):
self.fout = open(path, "wb") self.fout = open(path, "wb")
self.arch = arch self.arch = arch
self.offset_tensor = 0
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
self.kv_data = b""
self.kv_data_count = 0
self.ti_data = b""
self.ti_data_count = 0
self.add_architecture() self.add_architecture()
self.use_temp_file = use_temp_file self.use_temp_file = use_temp_file
self.tensors = [] self.tensors = []
@ -477,7 +483,7 @@ class GGUFWriter:
self.add_key(key) self.add_key(key)
self.add_val(val, GGUFValueType.ARRAY) self.add_val(val, GGUFValueType.ARRAY)
def add_val(self: str, val: Any, vtype: GGUFValueType = None, add_vtype: bool = True): def add_val(self, val: Any, vtype: Optional[GGUFValueType] = None, add_vtype: bool = True):
if vtype is None: if vtype is None:
vtype = GGUFValueType.get_type(val) vtype = GGUFValueType.get_type(val)
@ -545,15 +551,16 @@ class GGUFWriter:
self.ti_data_count += 1 self.ti_data_count += 1
def add_tensor(self, name: str, tensor: np.ndarray, raw_shape: Optional[np.ndarray] = None, raw_dtype: Optional[GGMLQuantizationType] = None): def add_tensor(self, name: str, tensor: np.ndarray, raw_shape: Optional[np.ndarray] = None, raw_dtype: Optional[GGMLQuantizationType] = None):
if self.use_temp_file and not hasattr(self, "temp_file"): if self.use_temp_file and self.temp_file is None:
self.temp_file = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024) fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
self.temp_file.seek(0) fp.seek(0)
self.temp_file = fp
self.add_tensor_info(name, raw_shape if raw_shape is not None else tensor.shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) self.add_tensor_info(name, raw_shape if raw_shape is not None else tensor.shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
if not self.use_temp_file: if self.temp_file is None:
self.tensors.append((tensor, pad)) self.tensors.append((tensor, pad))
return return
@ -562,23 +569,20 @@ class GGUFWriter:
if pad != 0: if pad != 0:
self.temp_file.write(bytes([0] * pad)) self.temp_file.write(bytes([0] * pad))
def write_padding(self, fp: BinaryIO, n: int, align: Optional[int] = None):
pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
if pad != 0:
fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray): def write_tensor_data(self, tensor: np.ndarray):
pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell() self.write_padding(self.fout, self.fout.tell())
if pad != 0:
self.fout.write(bytes([0] * pad))
tensor.tofile(self.fout) tensor.tofile(self.fout)
self.write_padding(self.fout, tensor.nbytes)
pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
if pad != 0:
self.fout.write(bytes([0] * pad))
def write_tensors_to_file(self): def write_tensors_to_file(self):
self.write_ti_data_to_file() self.write_ti_data_to_file()
pad = GGUFWriter.ggml_pad(self.fout.tell(), self.data_alignment) - self.fout.tell() self.write_padding(self.fout, self.fout.tell())
if pad != 0:
self.fout.write(bytes([0] * pad))
if not self.use_temp_file: if not self.use_temp_file:
for (currtensor, currpad) in self.tensors: for (currtensor, currpad) in self.tensors:
@ -654,10 +658,6 @@ class GGUFWriter:
self.add_bool( self.add_bool(
KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) KEY_USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
def add_tensor_data_layout(self, layout: str):
self.add_string(
KEY_TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
def add_head_count(self, count: int): def add_head_count(self, count: int):
self.add_uint32( self.add_uint32(
KEY_ATTENTION_HEAD_COUNT.format(arch=self.arch), count) KEY_ATTENTION_HEAD_COUNT.format(arch=self.arch), count)

0
gguf-py/gguf/py.typed Normal file
View file

View file

@ -5,6 +5,7 @@ description = "Write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"] authors = ["GGML <ggml@ggml.ai>"]
packages = [ packages = [
{include = "gguf"}, {include = "gguf"},
{include = "gguf/py.typed"},
] ]
readme = "README.md" readme = "README.md"
homepage = "https://ggml.ai" homepage = "https://ggml.ai"