remove SplitStrategy, SplitArguments

This commit is contained in:
Christian Zhou-Zheng 2024-06-09 13:08:06 -04:00
parent 0471f67f4f
commit 5a96b8f27f
2 changed files with 55 additions and 82 deletions

View file

@ -66,7 +66,7 @@ class Model:
model_arch: gguf.MODEL_ARCH model_arch: gguf.MODEL_ARCH
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool, def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
split_arguments: gguf.SplitArguments, model_name: str | None): model_name: str | None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = 0, small_first_shard: bool = 0):
if type(self) is Model: if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
self.dir_model = dir_model self.dir_model = dir_model
@ -97,8 +97,8 @@ class Model:
ftype_lw: str = ftype_up.lower() ftype_lw: str = ftype_up.lower()
# allow templating the file name with the output ftype, useful with the "auto" ftype # allow templating the file name with the output ftype, useful with the "auto" ftype
self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up) self.fname_out = fname_out.parent / fname_out.name.format(ftype_lw, outtype=ftype_lw, ftype=ftype_lw, OUTTYPE=ftype_up, FTYPE=ftype_up)
self.gguf_writer = gguf.GGUFWriter(None, gguf.MODEL_ARCH_NAMES[self.model_arch], split_arguments, self.gguf_writer = gguf.GGUFWriter(None, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
endianess=self.endianess, use_temp_file=self.use_temp_file) split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@classmethod @classmethod
def __init_subclass__(cls): def __init_subclass__(cls):
@ -334,7 +334,7 @@ class Model:
self.gguf_writer.close() self.gguf_writer.close()
def write_vocab(self): def write_vocab(self):
if self.gguf_writer.split_arguments.split_style != gguf.SplitStyle.NONE: if len(self.gguf_writer.tensors) != 1:
raise ValueError('Splitting the vocabulary is not supported') raise ValueError('Splitting the vocabulary is not supported')
self.gguf_writer.write_header_to_file(self.fname_out) self.gguf_writer.write_header_to_file(self.fname_out)
self.gguf_writer.write_kv_data_to_file() self.gguf_writer.write_kv_data_to_file()
@ -2806,11 +2806,11 @@ def parse_args() -> argparse.Namespace:
help="increase output verbosity", help="increase output verbosity",
) )
parser.add_argument( parser.add_argument(
"--split-max-tensors", type=int, "--split-max-tensors", type=int, default=0,
help="max tensors in each split", help="max tensors in each split",
) )
parser.add_argument( parser.add_argument(
"--split-max-size", type=str, "--split-max-size", type=str, default="0",
help="max size per split N(M|G)", help="max size per split N(M|G)",
) )
parser.add_argument( parser.add_argument(
@ -2825,6 +2825,24 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args() return parser.parse_args()
def split_str_to_n_bytes(split_str: str) -> int:
    """Parse a human-readable size string like "500M" or "2G" into a byte count.

    Accepts a plain number of bytes, or a number suffixed with K, M, or G
    (decimal multiples: 1K = 1000 bytes). "0" is valid and means "no size
    limit" — it is the default for --split-max-size.

    Raises:
        ValueError: if the string is malformed or the value is negative.
    """
    if split_str.endswith("K"):
        n = int(split_str[:-1]) * 1000
    elif split_str.endswith("M"):
        n = int(split_str[:-1]) * 1000 * 1000
    elif split_str.endswith("G"):
        n = int(split_str[:-1]) * 1000 * 1000 * 1000
    elif split_str.isnumeric():
        n = int(split_str)
    else:
        raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G")

    # 0 is allowed here (unlike the old GGUFWriter.split_str_to_n_bytes, which
    # used n <= 0) because "0" is the CLI default meaning "do not split".
    # Only a negative value (reachable via e.g. "-1K") is an error, so the
    # message must say "non-negative", not "positive".
    if n < 0:
        raise ValueError(f"Invalid split size: {split_str}, must be non-negative")

    return n
def main() -> None: def main() -> None:
args = parse_args() args = parse_args()
@ -2849,11 +2867,6 @@ def main() -> None:
logger.error(f'Error: {args.model} is not a directory') logger.error(f'Error: {args.model} is not a directory')
sys.exit(1) sys.exit(1)
if args.split_max_tensors and args.split_max_size:
raise ValueError("Can't specify both --split-max-tensors and --split-max-size")
split_arguments = gguf.SplitArguments(args)
ftype_map: dict[str, gguf.LlamaFileType] = { ftype_map: dict[str, gguf.LlamaFileType] = {
"f32": gguf.LlamaFileType.ALL_F32, "f32": gguf.LlamaFileType.ALL_F32,
"f16": gguf.LlamaFileType.MOSTLY_F16, "f16": gguf.LlamaFileType.MOSTLY_F16,
@ -2880,7 +2893,9 @@ def main() -> None:
sys.exit(1) sys.exit(1)
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file,
args.no_lazy, split_arguments, args.model_name) args.no_lazy, args.model_name, split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split)
logger.info("Set model parameters") logger.info("Set model parameters")
model_instance.set_gguf_parameters() model_instance.set_gguf_parameters()

View file

@ -5,8 +5,6 @@ import os
import shutil import shutil
import struct import struct
import tempfile import tempfile
from argparse import Namespace
from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from pathlib import Path from pathlib import Path
@ -56,17 +54,6 @@ class GGUFValue:
type: GGUFValueType type: GGUFValueType
class SplitArguments:
    """Bundle of the command-line options that control GGUF file splitting.

    NOTE(review): this class is the one being removed by this commit, in
    favour of passing the individual options (split_max_tensors,
    split_max_size, dry_run, small_first_shard) directly to GGUFWriter.
    """

    def __init__(self, args: Namespace) -> None:
        # Fall back to 0 ("no limit") when the option was not given.
        self.split_max_tensors = args.split_max_tensors if args.split_max_tensors else 0
        # --split-max-size arrives as a string like "500M"; convert to bytes.
        self.split_max_size = GGUFWriter.split_str_to_n_bytes(args.split_max_size) if args.split_max_size else 0
        # Tensor-count limit takes precedence over the size limit when both
        # are set (the caller is expected to reject that combination anyway).
        self.split_style = SplitStyle.TENSORS if self.split_max_tensors \
            else SplitStyle.SIZE if self.split_max_size \
            else SplitStyle.NONE
        self.dry_run = args.dry_run
        self.small_first_shard = args.no_tensor_first_split
class WriterState(Enum): class WriterState(Enum):
NO_FILE = auto() NO_FILE = auto()
EMPTY = auto() EMPTY = auto()
@ -76,12 +63,6 @@ class WriterState(Enum):
WEIGHTS = auto() WEIGHTS = auto()
class SplitStyle(Enum):
    """How a model file is split into shards.

    NOTE(review): removed by this commit; the split style is now implied by
    which of split_max_tensors / split_max_size is non-zero.
    """
    NONE = auto()     # single output file, no splitting
    TENSORS = auto()  # start a new shard after split_max_tensors tensors
    SIZE = auto()     # start a new shard after split_max_size bytes
class GGUFWriter: class GGUFWriter:
fout: list[BufferedWriter] | None fout: list[BufferedWriter] | None
path: os.PathLike[str] | str | None path: os.PathLike[str] | str | None
@ -104,40 +85,34 @@ class GGUFWriter:
} }
def __init__( def __init__(
self, path: os.PathLike[str] | str | None, arch: str, split_arguments: SplitArguments, self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE,
use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
):
self.fout = [] self.fout = []
self.path = path self.path = path
self.arch = arch self.arch = arch
self.endianess = endianess self.endianess = endianess
self.data_alignment = GGUF_DEFAULT_ALIGNMENT self.data_alignment = GGUF_DEFAULT_ALIGNMENT
self.split_arguments = split_arguments
self.use_temp_file = use_temp_file self.use_temp_file = use_temp_file
self.temp_file = None self.temp_file = None
self.tensors = [] self.tensors = []
self.kv_data = [dict()] self.kv_data = [dict()]
self.split_max_tensors = split_max_tensors
self.split_max_size = split_max_size
self.dry_run = dry_run
self.small_first_shard = small_first_shard
logger.info("gguf: This GGUF file is for {0} Endian only".format( logger.info("gguf: This GGUF file is for {0} Endian only".format(
"Big" if self.endianess == GGUFEndian.BIG else "Little", "Big" if self.endianess == GGUFEndian.BIG else "Little",
)) ))
self.state = WriterState.NO_FILE self.state = WriterState.NO_FILE
if self.split_arguments.small_first_shard: if self.small_first_shard:
self.tensors.append(dict()) self.tensors.append(dict())
self.add_architecture() self.add_architecture()
def verify_arguments(self) -> None: def verify_arguments(self) -> None:
total_tensors = sum(len(ti) for ti in self.tensors) if len(self.tensors) == 1:
total_size = sum(ti.nbytes for t in self.tensors for ti in t.values()) logger.warning("Model fails split requirements, not splitting")
if self.split_arguments.split_max_tensors and total_tensors < self.split_arguments.split_max_tensors:
logger.warning("Model has fewer tensors than the split threshold, not splitting")
self.split_style = SplitStyle.NONE
if self.split_arguments.split_max_size and total_size < self.split_arguments.split_max_size:
logger.warning("Model has smaller size than the split threshold, not splitting")
self.split_style = SplitStyle.NONE
# no shards are created when writing vocab so make one # no shards are created when writing vocab so make one
if not self.tensors or len(self.tensors) == 0: if not self.tensors or len(self.tensors) == 0:
@ -145,7 +120,7 @@ class GGUFWriter:
def format_shard_names(self, path: os.PathLike[str] | str | None = None) -> list[os.PathLike[str]]: def format_shard_names(self, path: os.PathLike[str] | str | None = None) -> list[os.PathLike[str]]:
pathobj = Path(path) pathobj = Path(path)
if self.split_arguments.split_style == SplitStyle.NONE: if len(self.tensors) == 1:
return [pathobj] return [pathobj]
shard_names = [] shard_names = []
@ -173,15 +148,16 @@ class GGUFWriter:
def print_plan(self, path: os.PathLike[str] | str | None = None) -> None: def print_plan(self, path: os.PathLike[str] | str | None = None) -> None:
logger.info("Writing the following files:") logger.info("Writing the following files:")
filenames = self.format_shard_names(path) filenames = self.format_shard_names(path)
for i in range(len(filenames)): assert len(filenames) == len(self.tensors)
logger.info(f"{filenames[i]}: n_tensors = {len(self.tensors[i])}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in self.tensors[i].values()))}") for name, tensors in zip(filenames, self.tensors):
logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}")
if self.split_arguments.dry_run: if self.dry_run:
logger.info("Dry run, not writing files") logger.info("Dry run, not writing files")
exit() exit()
def add_shard_kv_data(self) -> None: def add_shard_kv_data(self) -> None:
if self.split_arguments.split_style == SplitStyle.NONE: if len(self.tensors) == 1:
return return
total_tensors = sum(len(t) for t in self.tensors) total_tensors = sum(len(t) for t in self.tensors)
@ -318,8 +294,8 @@ class GGUFWriter:
if self.state is not WriterState.NO_FILE: if self.state is not WriterState.NO_FILE:
raise ValueError(f'Expected output file to be not yet opened, got {self.state}') raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
for i in range(len(self.tensors)): for tensors in self.tensors:
if name in self.tensors[i]: if name in tensors:
raise ValueError(f'Duplicated tensor name {name!r}') raise ValueError(f'Duplicated tensor name {name!r}')
if raw_dtype is None: if raw_dtype is None:
@ -345,13 +321,13 @@ class GGUFWriter:
tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype) tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
# create splits as necessary, such as to start it off # create splits as necessary, such as to start it off
if (len(self.tensors) == self.split_arguments.small_first_shard \ if (len(self.tensors) == self.small_first_shard \
# or split when over tensor limit # or split when over tensor limit
or (self.split_arguments.split_style == SplitStyle.TENSORS \ or self.split_max_tensors != 0 and \
and len(self.tensors[-1]) >= self.split_arguments.split_max_tensors) \ len(self.tensors[-1]) >= self.split_max_tensors \
# or split when over size limit # or split when over size limit
or (self.split_arguments.split_style == SplitStyle.SIZE \ or self.split_max_size != 0 and \
and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_arguments.split_max_size)): sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size):
self.tensors.append(dict()) self.tensors.append(dict())
@ -409,25 +385,25 @@ class GGUFWriter:
self.write_padding(fout, fout.tell()) self.write_padding(fout, fout.tell())
if self.temp_file is None: if self.temp_file is None:
for i in range(len(self.fout)): for fout, tensors in zip(self.fout, self.tensors):
assert self.fout[i] is not None assert fout is not None
bar = None bar = None
if progress: if progress:
from tqdm import tqdm from tqdm import tqdm
total_bytes = sum(ti.nbytes for ti in self.tensors[i].values()) total_bytes = sum(ti.nbytes for ti in tensors.values())
bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
# relying on the fact that Python dicts preserve insertion order (since 3.7) # relying on the fact that Python dicts preserve insertion order (since 3.7)
for ti in self.tensors[i].values(): for ti in tensors.values():
assert ti.tensor is not None # can only iterate once over the tensors assert ti.tensor is not None # can only iterate once over the tensors
assert ti.tensor.nbytes == ti.nbytes assert ti.tensor.nbytes == ti.nbytes
ti.tensor.tofile(self.fout[i]) ti.tensor.tofile(fout)
if bar is not None: if bar is not None:
bar.update(ti.nbytes) bar.update(ti.nbytes)
self.write_padding(self.fout[i], ti.nbytes) self.write_padding(fout, ti.nbytes)
ti.tensor = None ti.tensor = None
else: else:
self.temp_file.seek(0) self.temp_file.seek(0)
@ -731,24 +707,6 @@ class GGUFWriter:
assert fout is not None assert fout is not None
fout.write(self._pack(fmt, value, skip_pack_prefix)) fout.write(self._pack(fmt, value, skip_pack_prefix))
@staticmethod
def split_str_to_n_bytes(split_str: str) -> int:
    """Parse a size string like "500M" or "2G" into a byte count.

    Accepts a plain number of bytes, or a number suffixed with K, M, or G
    (decimal multiples: 1K = 1000 bytes).

    NOTE(review): this is the staticmethod being removed by this commit; its
    replacement (the module-level function added to the convert script)
    loosens the final check to `n < 0` so that "0" can mean "no limit".

    Raises:
        ValueError: if the string is malformed or the value is not positive.
    """
    if split_str.endswith("K"):
        n = int(split_str[:-1]) * 1000
    elif split_str.endswith("M"):
        n = int(split_str[:-1]) * 1000 * 1000
    elif split_str.endswith("G"):
        n = int(split_str[:-1]) * 1000 * 1000 * 1000
    elif split_str.isnumeric():
        n = int(split_str)
    else:
        raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G")

    # Here 0 is rejected too, consistent with the "must be positive" message.
    if n <= 0:
        raise ValueError(f"Invalid split size: {split_str}, must be positive")

    return n
@staticmethod @staticmethod
def format_n_bytes_to_str(num: int) -> str: def format_n_bytes_to_str(num: int) -> str:
if num == METADATA_ONLY_INDICATOR: if num == METADATA_ONLY_INDICATOR: