fix typing and clean up
This commit is contained in:
parent f7ecd99691
commit 9d7f694438
2 changed files with 39 additions and 62 deletions
@@ -66,7 +66,7 @@ class Model:
     model_arch: gguf.MODEL_ARCH
 
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
-                 model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = 0, small_first_shard: bool = 0):
+                 model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
@@ -2875,6 +2875,10 @@ def main() -> None:
         "auto": gguf.LlamaFileType.GUESSED,
     }
 
+    if args.use_temp_file and (args.split_max_tensors > 0 or args.split_max_size != "0"):
+        logger.error("Error: Cannot use temp file when splitting")
+        sys.exit(1)
+
     if args.outfile is not None:
         fname_out = args.outfile
     else:
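The new guard rejects combining the temp-file path with splitting: the temp-file mode buffers all tensor data through a single temporary file, which does not fit writing several shard files. A minimal sketch of the behavior, as a standalone function (the real check lives inline in main(); argument names are taken from the diff):

```python
# Standalone sketch of the new guard (illustrative, not the actual CLI wiring).
# split_max_size is compared against the string "0" in the diff, which
# suggests it is parsed as a string (e.g. with a size suffix).
def check_split_args(use_temp_file: bool, split_max_tensors: int, split_max_size: str) -> None:
    if use_temp_file and (split_max_tensors > 0 or split_max_size != "0"):
        raise SystemExit("Error: Cannot use temp file when splitting")

check_split_args(False, 256, "0")   # ok: splitting without a temp file
check_split_args(True, 0, "0")      # ok: temp file, no splitting
# check_split_args(True, 256, "0")  # would exit, like the new code path
```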
@@ -9,7 +9,7 @@ from dataclasses import dataclass
 from enum import Enum, auto
 from pathlib import Path
 from io import BufferedWriter
-from typing import IO, Any, Sequence, Mapping, TypeAlias
+from typing import IO, Any, Sequence, Mapping
 from string import ascii_letters, digits
 
 import numpy as np
@@ -33,11 +33,6 @@ logger = logging.getLogger(__name__)
 
 
 SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
-NUM_SHARD_KV_DATA = 3
-METADATA_ONLY_INDICATOR = -1
-
-KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType | None]]  # {key: (value, type)}
-TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any], GGMLQuantizationType | None]  # (tensor name, tensor data, tensor dtype)
 
 
 @dataclass
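The surviving SHARD_NAME_FORMAT constant fully determines shard file names. A quick check of the format string:

```python
SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"

# Shard indices are 1-based and zero-padded to five digits:
print(SHARD_NAME_FORMAT.format("Model-7B-F16", 3, 5))
# -> Model-7B-F16-00003-of-00005.gguf
```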
@@ -65,7 +60,7 @@ class WriterState(Enum):
 
 class GGUFWriter:
     fout: list[BufferedWriter] | None
-    path: os.PathLike[str] | str | None
+    path: Path | None
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
     tensors: list[dict[str, TensorInfo]]
     kv_data: list[dict[str, GGUFValue]]
@@ -88,15 +83,15 @@ class GGUFWriter:
         self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
         split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False
     ):
-        self.fout = []
-        self.path = path
+        self.fout = None
+        self.path = Path(path) if path else None
         self.arch = arch
         self.endianess = endianess
         self.data_alignment = GGUF_DEFAULT_ALIGNMENT
         self.use_temp_file = use_temp_file
         self.temp_file = None
-        self.tensors = []
-        self.kv_data = [dict()]
+        self.tensors = [{}]
+        self.kv_data = [{}]
         self.split_max_tensors = split_max_tensors
         self.split_max_size = split_max_size
         self.dry_run = dry_run
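The constructor now normalizes `path` once at the boundary, so every other method can rely on `self.path` being `Path | None` instead of re-wrapping `os.PathLike[str] | str`. A minimal sketch of the same idiom, standalone:

```python
import os
from pathlib import Path

# Normalize str / os.PathLike / None input once, at the boundary
# (sketch of the `Path(path) if path else None` idiom from the diff).
def normalize(path: os.PathLike[str] | str | None) -> Path | None:
    return Path(path) if path else None

assert normalize("models/out.gguf") == Path("models/out.gguf")
assert normalize(Path("out.gguf")) == Path("out.gguf")
assert normalize(None) is None
```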
@@ -107,30 +102,16 @@ class GGUFWriter:
         self.state = WriterState.NO_FILE
 
         if self.small_first_shard:
-            self.tensors.append(dict())
+            self.tensors.append({})
 
         self.add_architecture()
 
-    def verify_arguments(self) -> None:
+    def format_shard_names(self, path: Path) -> list[Path]:
         if len(self.tensors) == 1:
-            logger.warning("Model fails split requirements, not splitting")
+            return [path]
+        return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))]
 
-        # no shards are created when writing vocab so make one
-        if not self.tensors or len(self.tensors) == 0:
-            self.tensors = [dict()]
-
-    def format_shard_names(self, path: os.PathLike[str] | str | None = None) -> list[os.PathLike[str]]:
-        pathobj = Path(path)
-        if len(self.tensors) == 1:
-            return [pathobj]
-
-        shard_names = []
-        for i in range(len(self.tensors)):
-            shard_names.append(pathobj.with_name(SHARD_NAME_FORMAT.format(pathobj.stem, i + 1, len(self.tensors))))
-
-        return shard_names
-
-    def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None:
+    def open_output_file(self, path: Path | None = None) -> None:
         if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
             # allow calling this multiple times as long as the path is the same
             return
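Collapsing the old append loop into a comprehension leaves the shard-name expansion unchanged. A standalone re-implementation for illustration, reusing SHARD_NAME_FORMAT from above:

```python
from pathlib import Path

SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"

# Sketch of the new format_shard_names: one name per shard,
# or the path unchanged when there is nothing to split.
def shard_names(path: Path, n_shards: int) -> list[Path]:
    if n_shards == 1:
        return [path]
    return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, n_shards))
            for i in range(n_shards)]

print(shard_names(Path("out/model.gguf"), 2))
# e.g. on POSIX:
# [PosixPath('out/model-00001-of-00002.gguf'), PosixPath('out/model-00002-of-00002.gguf')]
```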
@@ -141,14 +122,14 @@ class GGUFWriter:
             self.path = path
 
         if self.path is not None:
-            self.fout = []
-            for fout in self.format_shard_names(self.path):
-                self.fout.append(open(fout, "wb"))
+            self.print_plan()
+            self.fout = [open(filename, "wb") for filename in self.format_shard_names(self.path)]
             self.state = WriterState.EMPTY
 
-    def print_plan(self, path: os.PathLike[str] | str | None = None) -> None:
+    def print_plan(self) -> None:
         logger.info("Writing the following files:")
-        filenames = self.format_shard_names(path)
+        assert self.path is not None
+        filenames = self.format_shard_names(self.path)
         assert len(filenames) == len(self.tensors)
         for name, tensors in zip(filenames, self.tensors):
             logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
@@ -162,24 +143,28 @@ class GGUFWriter:
             return
 
         total_tensors = sum(len(t) for t in self.tensors)
-        for i in range(len(self.fout)):
+        assert self.fout is not None
+        total_splits = len(self.fout)
+        for i in range(total_splits):
             # just see whether it exists
             try:
                 self.kv_data[i]
             except IndexError:
                 self.kv_data.append(dict())
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16)
-            self.kv_data[i][Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(len(self.fout), GGUFValueType.UINT16)
+            self.kv_data[i][Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16)
             self.kv_data[i][Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32)
 
-    def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
-        self.verify_arguments()
-        self.print_plan(path)
+    def write_header_to_file(self, path: Path | None = None) -> None:
+        if len(self.tensors) == 1:
+            logger.warning("Model fails split requirements, not splitting")
+
         self.open_output_file(path)
 
         if self.state is not WriterState.EMPTY:
             raise ValueError(f'Expected output file to be empty, got {self.state}')
 
         assert self.fout is not None
         assert len(self.fout) == len(self.tensors)
         assert len(self.kv_data) == 1
 
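Each shard thus carries three split-related KV entries, which lets a reader validate a multi-file model. The literal key strings behind Keys.Split.* are not shown in this diff, so the names below are assumptions that only illustrate the shape of the metadata:

```python
# Illustration of the per-shard split metadata written above; the key
# strings are assumed, the real ones come from gguf's Keys.Split constants.
total_splits = 3
total_tensors = 291
split_kv = [
    {
        "split.no": i,                         # UINT16: this shard's index
        "split.count": total_splits,           # UINT16: total number of shards
        "split.tensors.count": total_tensors,  # INT32: tensors across all shards
    }
    for i in range(total_splits)
]
assert split_kv[0]["split.no"] == 0 and split_kv[-1]["split.no"] == 2
```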
@@ -216,7 +201,6 @@ class GGUFWriter:
         assert self.fout is not None
 
         for fout, tensors in zip(self.fout, self.tensors):
-            assert fout is not None
             ti_data = bytearray()
             offset_tensor = 0
 
@@ -235,7 +219,7 @@ class GGUFWriter:
         self.state = WriterState.TI_DATA
 
     def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
-        if key in self.kv_data:
+        if any(key in kv_data for kv_data in self.kv_data):
             raise ValueError(f'Duplicated key name {key!r}')
 
         self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
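Since kv_data is now a list of per-shard dicts, a duplicate key must be searched across all of them; `any()` with a generator does this lazily and fixes the old check, which tested membership in the outer list rather than in any dict. The same pattern in isolation:

```python
# kv_data is a list of per-shard dicts; a key is a duplicate if it is
# already present in any shard, not just the first.
kv_data: list[dict[str, int]] = [{"general.architecture": 1}, {"split.no": 2}]

def is_duplicate(key: str) -> bool:
    return any(key in kv for kv in kv_data)

assert is_duplicate("split.no")
assert not is_duplicate("general.name")
```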
@@ -279,9 +263,6 @@ class GGUFWriter:
         self.add_key_value(key, val, GGUFValueType.STRING)
 
     def add_array(self, key: str, val: Sequence[Any]) -> None:
-        if not isinstance(val, Sequence):
-            raise ValueError("Value must be a sequence for array type")
-
         self.add_key_value(key, val, GGUFValueType.ARRAY)
 
     @staticmethod
@@ -295,9 +276,8 @@ class GGUFWriter:
         if self.state is not WriterState.NO_FILE:
             raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
 
-        for tensors in self.tensors:
-            if name in tensors:
-                raise ValueError(f'Duplicated tensor name {name!r}')
+        if any(name in tensors for tensors in self.tensors):
+            raise ValueError(f'Duplicated tensor name {name!r}')
 
         if raw_dtype is None:
             if tensor_dtype == np.float16:
@@ -321,10 +301,8 @@ class GGUFWriter:
         if tensor_dtype == np.uint8:
             tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
 
-        # create splits as necessary, such as to start it off
-        if (len(self.tensors) == self.small_first_shard \
-            # or split when over tensor limit
-            or self.split_max_tensors != 0 and \
+        # split when over tensor limit
+        if (self.split_max_tensors != 0 and \
             len(self.tensors[-1]) >= self.split_max_tensors \
             # or split when over size limit
             or self.split_max_size != 0 and \
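After the cleanup the condition covers only the two runtime limits; the `small_first_shard` case no longer needs a branch here because `self.tensors` already starts with the extra empty dict from the constructor. The hunk is truncated before the size check completes, so the sketch below fills that part in as an assumption:

```python
# Sketch of the shard-split trigger. The size branch is cut off in the
# hunk; its exact expression here is an assumption, not the diff's code.
def should_start_new_shard(n_in_shard: int, shard_nbytes: int, next_nbytes: int,
                           split_max_tensors: int, split_max_size: int) -> bool:
    over_tensor_limit = split_max_tensors != 0 and n_in_shard >= split_max_tensors
    over_size_limit = split_max_size != 0 and shard_nbytes + next_nbytes > split_max_size
    return over_tensor_limit or over_size_limit
```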
@@ -369,7 +347,6 @@ class GGUFWriter:
             tensor.byteswap(inplace=True)
 
         for fout in self.fout:
-            assert fout is not None
             self.write_padding(fout, fout.tell())
             tensor.tofile(fout)
             self.write_padding(fout, tensor.nbytes)
@@ -382,12 +359,10 @@ class GGUFWriter:
         assert self.fout is not None
 
         for fout in self.fout:
-            assert fout is not None
             self.write_padding(fout, fout.tell())
 
         if self.temp_file is None:
             for fout, tensors in zip(self.fout, self.tensors):
-                assert fout is not None
                 bar = None
 
                 if progress:
@@ -409,7 +384,8 @@ class GGUFWriter:
         else:
             self.temp_file.seek(0)
 
-            shutil.copyfileobj(self.temp_file, self.fout)
+            assert self.fout is not None
+            shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1])
             self.flush()
             self.temp_file.close()
 
@@ -418,14 +394,12 @@ class GGUFWriter:
     def flush(self) -> None:
         assert self.fout is not None
         for fout in self.fout:
-            assert fout is not None
             fout.flush()
 
     def close(self) -> None:
         if self.fout is not None:
             for fout in self.fout:
-                if fout is not None:
-                    fout.close()
+                fout.close()
             self.fout = []
 
     def add_architecture(self) -> None:
@@ -705,12 +679,11 @@ class GGUFWriter:
         return kv_data
 
     def _write_packed(self, fout: BufferedWriter, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None:
-        assert fout is not None
         fout.write(self._pack(fmt, value, skip_pack_prefix))
 
     @staticmethod
     def format_n_bytes_to_str(num: int) -> str:
-        if num == METADATA_ONLY_INDICATOR:
+        if num == 0:
             return "negligible - metadata only"
         fnum = float(num)
         for unit in ("", "K", "M", "G"):
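With METADATA_ONLY_INDICATOR gone, a size of 0 is simply the "metadata only" case. The hunk ends mid-function; a completed sketch, assuming the conventional divide-by-1000 tail:

```python
# Completed sketch of the formatter; the loop body and the terabyte
# fallback are assumptions, since the hunk cuts off inside the loop.
def format_n_bytes_to_str(num: int) -> str:
    if num == 0:
        return "negligible - metadata only"
    fnum = float(num)
    for unit in ("", "K", "M", "G"):
        if abs(fnum) < 1000.0:
            return f"{fnum:3.1f}{unit}"
        fnum /= 1000.0
    return f"{fnum:.1f}T - over 1TB, split recommended"

print(format_n_bytes_to_str(0))          # negligible - metadata only
print(format_n_bytes_to_str(7_400_000))  # 7.4M
```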