fix typing and clean up

Christian Zhou-Zheng 2024-06-09 16:02:23 -04:00
parent f7ecd99691
commit 9d7f694438
2 changed files with 39 additions and 62 deletions

convert-hf-to-gguf.py

@@ -66,7 +66,7 @@ class Model:
     model_arch: gguf.MODEL_ARCH

     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
-                 model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = 0, small_first_shard: bool = 0):
+                 model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
@@ -2875,6 +2875,10 @@ def main() -> None:
         "auto": gguf.LlamaFileType.GUESSED,
     }

+    if args.use_temp_file and (args.split_max_tensors > 0 or args.split_max_size != "0"):
+        logger.error("Error: Cannot use temp file when splitting")
+        sys.exit(1)
+
     if args.outfile is not None:
         fname_out = args.outfile
     else:
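
For context, the guard added above makes --use-temp-file mutually exclusive with splitting: temp-file mode buffers all tensors into one spooled file, which cannot be distributed across shards. Below is a minimal standalone sketch of the same check; the argparse wiring is assumed for illustration, and only the condition itself mirrors the diff. Note that split_max_size is compared against the string "0", presumably because it is parsed later from a human-readable size argument.

    import argparse
    import logging
    import sys

    logger = logging.getLogger(__name__)

    parser = argparse.ArgumentParser()
    parser.add_argument("--use-temp-file", action="store_true")
    parser.add_argument("--split-max-tensors", type=int, default=0)
    parser.add_argument("--split-max-size", type=str, default="0")  # size string, e.g. "4G" (assumed)
    args = parser.parse_args()

    # splitting is requested if either limit is set; temp-file buffering cannot be sharded
    if args.use_temp_file and (args.split_max_tensors > 0 or args.split_max_size != "0"):
        logger.error("Error: Cannot use temp file when splitting")
        sys.exit(1)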

gguf-py/gguf/gguf_writer.py

@@ -9,7 +9,7 @@ from dataclasses import dataclass
 from enum import Enum, auto
 from pathlib import Path
 from io import BufferedWriter
-from typing import IO, Any, Sequence, Mapping, TypeAlias
+from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits

 import numpy as np
@@ -33,11 +33,6 @@ logger = logging.getLogger(__name__)

 SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"

-NUM_SHARD_KV_DATA = 3
-METADATA_ONLY_INDICATOR = -1
-
-KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType | None]]  # {key: (value, type)}
-TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any], GGMLQuantizationType | None]  # (tensor name, tensor data, tensor dtype)

 @dataclass
@@ -65,7 +60,7 @@ class WriterState(Enum):

 class GGUFWriter:
     fout: list[BufferedWriter] | None
-    path: os.PathLike[str] | str | None
+    path: Path | None
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
     tensors: list[dict[str, TensorInfo]]
     kv_data: list[dict[str, GGUFValue]]
@@ -88,15 +83,15 @@ class GGUFWriter:
         self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
         split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
     ):
-        self.fout = []
-        self.path = path
+        self.fout = None
+        self.path = Path(path) if path else None
         self.arch = arch
         self.endianess = endianess
         self.data_alignment = GGUF_DEFAULT_ALIGNMENT
         self.use_temp_file = use_temp_file
         self.temp_file = None
-        self.tensors = []
-        self.kv_data = [dict()]
+        self.tensors = [{}]
+        self.kv_data = [{}]
         self.split_max_tensors = split_max_tensors
         self.split_max_size = split_max_size
         self.dry_run = dry_run
@@ -107,30 +102,16 @@ class GGUFWriter:
         self.state = WriterState.NO_FILE

         if self.small_first_shard:
-            self.tensors.append(dict())
+            self.tensors.append({})

         self.add_architecture()

-    def verify_arguments(self) -> None:
-        if len(self.tensors) == 1:
-            logger.warning("Model fails split requirements, not splitting")
-
-        # no shards are created when writing vocab so make one
-        if not self.tensors or len(self.tensors) == 0:
-            self.tensors = [dict()]
-
-    def format_shard_names(self, path: os.PathLike[str] | str | None = None) -> list[os.PathLike[str]]:
-        pathobj = Path(path)
+    def format_shard_names(self, path: Path) -> list[Path]:
         if len(self.tensors) == 1:
-            return [pathobj]
-        shard_names = []
-        for i in range(len(self.tensors)):
-            shard_names.append(pathobj.with_name(SHARD_NAME_FORMAT.format(pathobj.stem, i + 1, len(self.tensors))))
-        return shard_names
+            return [path]
+        return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))]

-    def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None:
+    def open_output_file(self, path: Path | None = None) -> None:
         if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
             # allow calling this multiple times as long as the path is the same
             return
@@ -141,14 +122,14 @@ class GGUFWriter:
         self.path = path

         if self.path is not None:
-            self.fout = []
-            for fout in self.format_shard_names(self.path):
-                self.fout.append(open(fout, "wb"))
+            self.print_plan()
+            self.fout = [open(filename, "wb") for filename in self.format_shard_names(self.path)]
             self.state = WriterState.EMPTY

-    def print_plan(self, path: os.PathLike[str] | str | None = None) -> None:
+    def print_plan(self) -> None:
         logger.info("Writing the following files:")
-        filenames = self.format_shard_names(path)
+        assert self.path is not None
+        filenames = self.format_shard_names(self.path)
         assert len(filenames) == len(self.tensors)
         for name, tensors in zip(filenames, self.tensors):
             logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
@@ -162,24 +143,28 @@ class GGUFWriter:
             return

         total_tensors = sum(len(t) for t in self.tensors)
-        for i in range(len(self.fout)):
+        assert self.fout is not None
+        total_splits = len(self.fout)
+        for i in range(total_splits):
             # just see whether it exists
             try:
                 self.kv_data[i]
             except IndexError:
                 self.kv_data.append(dict())
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
-            self.kv_data[i][Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(len(self.fout), GGUFValueType.UINT16)
+            self.kv_data[i][Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)

-    def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
-        self.verify_arguments()
-        self.print_plan(path)
+    def write_header_to_file(self, path: Path | None = None) -> None:
+        if len(self.tensors) == 1:
+            logger.warning("Model fails split requirements, not splitting")
+
         self.open_output_file(path)

         if self.state is not WriterState.EMPTY:
             raise ValueError(f'Expected output file to be empty, got {self.state}')

+        assert self.fout is not None
         assert len(self.fout) == len(self.tensors)
         assert len(self.kv_data) == 1
@@ -216,7 +201,6 @@ class GGUFWriter:
         assert self.fout is not None

         for fout, tensors in zip(self.fout, self.tensors):
-            assert fout is not None
             ti_data = bytearray()
             offset_tensor = 0
@@ -235,7 +219,7 @@ class GGUFWriter:
         self.state = WriterState.TI_DATA

     def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
-        if key in self.kv_data:
+        if any(key in kv_data for kv_data in self.kv_data):
             raise ValueError(f'Duplicated key name {key!r}')

         self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
@@ -279,9 +263,6 @@ class GGUFWriter:
         self.add_key_value(key, val, GGUFValueType.STRING)

     def add_array(self, key: str, val: Sequence[Any]) -> None:
-        if not isinstance(val, Sequence):
-            raise ValueError("Value must be a sequence for array type")
-
         self.add_key_value(key, val, GGUFValueType.ARRAY)

     @staticmethod
@@ -295,9 +276,8 @@ class GGUFWriter:
         if self.state is not WriterState.NO_FILE:
             raise ValueError(f'Expected output file to be not yet opened, got {self.state}')

-        for tensors in self.tensors:
-            if name in tensors:
-                raise ValueError(f'Duplicated tensor name {name!r}')
+        if any(name in tensors for tensors in self.tensors):
+            raise ValueError(f'Duplicated tensor name {name!r}')

         if raw_dtype is None:
             if tensor_dtype == np.float16:
@@ -321,10 +301,8 @@ class GGUFWriter:
             if tensor_dtype == np.uint8:
                 tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)

-        # create splits as necessary, such as to start it off
-        if (len(self.tensors) == self.small_first_shard \
-            # or split when over tensor limit
-            or self.split_max_tensors != 0 and \
+        # split when over tensor limit
+        if (self.split_max_tensors != 0 and \
             len(self.tensors[-1]) >= self.split_max_tensors \
             # or split when over size limit
             or self.split_max_size != 0 and \
@@ -369,7 +347,6 @@ class GGUFWriter:
             tensor.byteswap(inplace=True)

         for fout in self.fout:
-            assert fout is not None
             self.write_padding(fout, fout.tell())
             tensor.tofile(fout)
             self.write_padding(fout, tensor.nbytes)
@@ -382,12 +359,10 @@ class GGUFWriter:
         assert self.fout is not None

         for fout in self.fout:
-            assert fout is not None
             self.write_padding(fout, fout.tell())

         if self.temp_file is None:
             for fout, tensors in zip(self.fout, self.tensors):
-                assert fout is not None
                 bar = None

                 if progress:
@@ -409,7 +384,8 @@ class GGUFWriter:
         else:
             self.temp_file.seek(0)

-            shutil.copyfileobj(self.temp_file, self.fout)
+            assert self.fout is not None
+            shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1])
             self.flush()
             self.temp_file.close()
@@ -418,14 +394,12 @@ class GGUFWriter:
     def flush(self) -> None:
         assert self.fout is not None
         for fout in self.fout:
-            assert fout is not None
             fout.flush()

     def close(self) -> None:
         if self.fout is not None:
             for fout in self.fout:
-                if fout is not None:
-                    fout.close()
+                fout.close()
             self.fout = []

     def add_architecture(self) -> None:
@@ -705,12 +679,11 @@ class GGUFWriter:
         return kv_data

     def _write_packed(self, fout: BufferedWriter, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
-        assert fout is not None
         fout.write(self._pack(fmt, value, skip_pack_prefix))

     @staticmethod
     def format_n_bytes_to_str(num: int) -> str:
-        if num == METADATA_ONLY_INDICATOR:
+        if num == 0:
             return "negligible - metadata only"
         fnum = float(num)
         for unit in ("", "K", "M", "G"):
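
Taken together, the gguf_writer.py changes settle the sharding API: fout becomes a list of per-shard file handles opened by open_output_file, tensors and kv_data start as one-element lists, and shard names come from SHARD_NAME_FORMAT. A minimal usage sketch under those assumptions follows; the tensor names and shapes are hypothetical, and the call sequence is inferred from the method names visible in this diff rather than any documented guarantee.

    import numpy as np
    from gguf import GGUFWriter  # gguf-py, as modified in this commit

    # at most two tensors per shard, plus a metadata-only first shard
    writer = GGUFWriter("model.gguf", arch="llama",
                        split_max_tensors=2, small_first_shard=True)

    for i in range(4):  # hypothetical tensors: 4 tensors -> 2 data shards
        writer.add_tensor(f"blk.{i}.weight", np.zeros((16, 16), dtype=np.float32))

    writer.write_header_to_file()   # opens model-00001-of-00003.gguf and friends
    writer.write_kv_data_to_file()  # split.no / split.count / split.tensors.count land in each shard
    writer.write_tensors_to_file()
    writer.close()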