From ce7e6985d2b233f0bbd62969689962635aca0898 Mon Sep 17 00:00:00 2001
From: Christian Zhou-Zheng
Date: Wed, 5 Jun 2024 18:29:39 -0400
Subject: [PATCH] form shards while adding tensors, SHA256 sums agree with master

---
 convert-hf-to-gguf.py        |   8 +-
 gguf-py/gguf/gguf_manager.py | 341 +++++++++++++++++------------------
 gguf-py/gguf/gguf_writer.py  |   5 +-
 3 files changed, 176 insertions(+), 178 deletions(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index b4399f680..b6fd4bc4b 100644
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -327,6 +327,7 @@ class Model:
 
     def write(self):
         self.write_tensors()
+        self.gguf_writer.init_shards()
         self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.write_tensors_to_file(progress=True)
@@ -335,6 +336,7 @@ class Model:
     def write_vocab(self):
         if self.gguf_writer.split_arguments.split:
             raise ValueError('Splitting the vocabulary is not supported')
+        self.gguf_writer.init_shards()
         self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()
@@ -2816,8 +2818,8 @@ def parse_args() -> argparse.Namespace:
         help="only print out a split plan and exit, without writing any new files"
     )
     parser.add_argument(
-        "--large-first-shard", action="store_true",
-        help="include tensors in the first shard when splitting (default: metadata only)"
+        "--small-first-shard", action="store_true",
+        help="do not add tensors to the first shard (disabled by default)"
    )
     return parser.parse_args()
 
@@ -2853,7 +2855,7 @@ def main() -> None:
     if args.split_max_tensors and args.split_max_size:
         raise ValueError("Can't specify both --split-max-tensors and --split-max-size")
 
-    split_arguments = gguf.SplitArguments(args=args) if args.split else gguf.SplitArguments()
+    split_arguments = gguf.SplitArguments(args)
 
     ftype_map = {
         "f32": gguf.LlamaFileType.ALL_F32,
diff --git a/gguf-py/gguf/gguf_manager.py b/gguf-py/gguf/gguf_manager.py
index 2605b816a..2fcaf3edf 100644
--- a/gguf-py/gguf/gguf_manager.py
+++ b/gguf-py/gguf/gguf_manager.py
@@ -4,7 +4,6 @@ import os
 from enum import IntEnum
 from typing import TYPE_CHECKING, Any, Sequence
 from argparse import Namespace
-from math import ceil
 from collections import deque
 
 import numpy as np
@@ -21,14 +20,15 @@ from .gguf_writer import GGUFWriter, WriterState
 
 SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
+METADATA_ONLY_INDICATOR = -1
 
 LLM_KV_SPLIT_NO = "split.no"
 LLM_KV_SPLIT_COUNT = "split.count"
 LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
 
-SplitTensorsPerFile: TypeAlias = deque[tuple[os.PathLike[str], deque[tuple[str, Any]], GGUFWriter]] # [(outfile name, [(tensor name, tensor data)] for each tensor in file, filewriter)]
 KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType]] # {key: (value, type)}
 TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any], GGMLQuantizationType] # (tensor name, tensor data, tensor dtype)
+Shard: TypeAlias = list[Any] # heterogeneous, mutable: [shard filename, shard tensor count, shard size, deque of tensor data]
 
 
 class SplitStyle(IntEnum):
@@ -38,56 +38,182 @@ class SplitStyle(IntEnum):
 
 
 class SplitArguments:
-    def __init__(self, args: Namespace = None) -> None:
-        self.split = args.split if args else False
-        self.split_max_tensors = args.split_max_tensors if args else 0
-        self.split_max_size = SplitStrategy.split_str_to_n_bytes(args.split_max_size) if args and args.split_max_size else 0
-        self.dry_run = args.dry_run if args else False
-        self.small_first_shard = not args.large_first_shard if args else False
-        self.split_style = SplitStyle.NONE if not self.split or not args \
+    def __init__(self, args: Namespace) -> None:
+        self.split = args.split
+        self.split_max_tensors = args.split_max_tensors if args.split else 0
+        self.split_max_size = GGUFManager.split_str_to_n_bytes(args.split_max_size) if args.split and args.split_max_size else 0
+        self.split_style = SplitStyle.NONE if not self.split \
             else SplitStyle.TENSORS if self.split_max_tensors \
             else SplitStyle.SIZE
+        self.dry_run = args.dry_run
+        self.small_first_shard = args.small_first_shard
 
 
-class SplitStrategy(deque):
-    data: SplitTensorsPerFile
+class GGUFManager(GGUFWriter):
+    kv_data: KVTempData
+    tensors: list[TensorTempData]
+    split_arguments: SplitArguments
+    shards: list[Shard]
+    shard_writers: list[GGUFWriter]
 
-    def __init__(self, fname_out: os.PathLike[str], model: list[TensorTempData], arch: str,
-        split_arguments: SplitArguments, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE,
-        ):
-        super().__init__()
+    def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
+                 use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
+                 ) -> None:
+        # we intentionally don't call the superclass constructor - it would open the output file immediately
+        self.arch = arch
+        self.path = path
+        self.endianess = endianess
+        self.kv_data = {}
+        self.shards = []
+        self.shard_writers = []
+        self.total_tensors = 0
+        self.use_temp_file = use_temp_file
+        self.split_arguments = split_arguments
+        self.recent_key = None
+        self.state = WriterState.EMPTY
 
-        if split_arguments.split_style == SplitStyle.NONE:
-            self.append((fname_out, model, GGUFWriter(fname_out, arch, use_temp_file=use_temp_file, endianess=endianess)))
+        if self.split_arguments.small_first_shard:
+            self.shards.append(["", 0, METADATA_ONLY_INDICATOR, None])
 
-        elif split_arguments.split_style == SplitStyle.TENSORS:
-            total_shards = ceil(len(model) / split_arguments.split_max_tensors) + split_arguments.small_first_shard
-            shard_files = [fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + 1, total_shards)) for i in range(total_shards)]
+    def init_shards(self) -> None:
+        self.total_tensors = sum(shard[1] for shard in self.shards)
+        total_size = sum(shard[2] for shard in self.shards)
 
-            if split_arguments.small_first_shard:
-                self.append((shard_files[0], None, GGUFWriter(shard_files[0], arch, use_temp_file=use_temp_file, endianess=endianess)))
+        # check if we need to split
+        if self.split_arguments.split_max_tensors and self.total_tensors < self.split_arguments.split_max_tensors:
+            print("Model has fewer tensors than the split threshold, not splitting")
+            self.split_arguments.split_style = SplitStyle.NONE
 
-            for i, shard in enumerate(shard_files[split_arguments.small_first_shard:]):
-                start = i * split_arguments.split_max_tensors
-                stop = min((i + 1) * split_arguments.split_max_tensors, len(model))
-                self.append((shard, model[start:stop], GGUFWriter(shard, arch, use_temp_file=use_temp_file, endianess=endianess)))
+        if self.split_arguments.split_max_size and total_size < self.split_arguments.split_max_size:
+            print("Model is smaller than the split threshold, not splitting")
+            self.split_arguments.split_style = SplitStyle.NONE
 
-        elif split_arguments.split_style == SplitStyle.SIZE:
-            shards = [[model[0]]]
+        # no shards are created when writing vocab so make one
+        if not self.shards:
+            self.shards.append(["", 0, METADATA_ONLY_INDICATOR, None])
 
-            # we have to determine the shards first to determine how many shards there will be in total - two passes
-            for i, shard in enumerate(model[1:]):
-                if SplitStrategy.get_tensor_size(shard[1]) + sum(SplitStrategy.get_tensor_size(t[1]) for t in shards[-1]) > split_arguments.split_max_size:
-                    shards.append([shard])
-                else:
-                    shards[-1].append(shard)
+        # format shard names
+        if len(self.shards) == 1:
+            self.shards[0][0] = self.path
+        else:
+            for i in range(len(self.shards)):
+                self.shards[i][0] = self.path.with_name(SHARD_NAME_FORMAT.format(self.path.stem, i + 1, len(self.shards)))
 
-            if split_arguments.small_first_shard:
-                shards.insert(0, None)
+        # print shard info
+        print("\nWriting the following files:")
+        for (path, tensor_ct, size, _) in self.shards:
+            print(f"  {path}: n_tensors = {tensor_ct}, total_size = {GGUFManager.format_n_bytes_to_str(size)}")
+        print()
 
-            for i, shard in enumerate(shards):
-                outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + 1, len(shards)))
-                self.append((outname, shard, GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
+        if self.split_arguments.dry_run:
+            print("\nDry run, not writing files")
+            exit()
+
+        # we don't want to initialize GGUFWriters until now because they create files
+        for i, (path, _, _, tensors) in enumerate(self.shards):
+            # dont_add_architecture keeps output consistent with examples/gguf_split, which only adds the architecture to the first shard
+            writer = GGUFWriter(path, self.arch, use_temp_file=self.use_temp_file,
+                                endianess=self.endianess, dont_add_architecture=(i != 0))
+
+            # only the first shard needs all the KV data
+            if i == 0:
+                for key, (value, etype) in self.kv_data.items():
+                    writer.add_key(key)
+                    writer.add_val(value, etype)
+
+            # add split metadata unless we are writing a single file - a small first shard means we split even with SplitStyle.NONE
+            if self.split_arguments.split_style != SplitStyle.NONE or self.split_arguments.small_first_shard:
+                writer.add_uint16(LLM_KV_SPLIT_NO, i)
+                writer.add_uint16(LLM_KV_SPLIT_COUNT, len(self.shards))
+                writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors)
+
+            # add tensors; popleft() drops our reference to each eager tensor once it is handed to the writer
+            while True:
+                try:
+                    (name, tensor, dtype) = tensors.popleft()
+                    writer.add_tensor(name, tensor, raw_dtype=dtype)
+                except (IndexError, AttributeError):  # deque exhausted, or tensors is None for a metadata-only shard
+                    break
+
+            self.shard_writers.append(writer)
+
+    def write_header_to_file(self) -> None:
+        if self.state is not WriterState.EMPTY:
+            raise ValueError(f'Expected GGUFManager state to be EMPTY, got {self.state}')
+
+        for writer in self.shard_writers:
+            writer.write_header_to_file()
+
+        self.state = WriterState.HEADER
+
+    def write_kv_data_to_file(self) -> None:
+        if self.state is not WriterState.HEADER:
+            raise ValueError(f'Expected GGUFManager state to be HEADER, got {self.state}')
+
+        for writer in self.shard_writers:
+            writer.write_kv_data_to_file()
+
+        self.state = WriterState.KV_DATA
+
+    def write_tensors_to_file(self, progress: bool = False) -> None:
+        if self.state is not WriterState.KV_DATA:
+            raise ValueError(f'Expected GGUFManager state to be KV_DATA, got {self.state}')
+
+        running_total = self.total_tensors
+        for i in range(len(self.shard_writers)):
+            writer = self.shard_writers[i]
+            is_metadata = writer.ti_data_count == 0
+            if is_metadata:
+                print(f"Writing to shard {i + 1}/{len(self.shards)} with metadata only")
+            else:
+                print(f"Writing to shard {i + 1}/{len(self.shards)} with {writer.ti_data_count}/{running_total} remaining tensors (of {self.total_tensors} total)")
+            running_total -= writer.ti_data_count
+            writer.write_tensors_to_file(progress=(progress and not is_metadata))
+            del writer
+
+        self.state = WriterState.TI_DATA
+
+    # override add_key, add_val to handle kv data separately
+    def add_key(self, key: str) -> None:
+        self.recent_key = key
+
+    def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
+        if self.recent_key is None:
+            raise ValueError("No key set for value")
+        self.kv_data[self.recent_key] = (val, vtype)
+
+    # need to handle arrays separately
+    def add_array(self, key: str, val: Sequence[Any]) -> None:
+        if not isinstance(val, Sequence):
+            raise ValueError(f'Expected a sequence for {key}, got {type(val)}')
+        self.kv_data[key] = (val, GGUFValueType.ARRAY)
+
+    def add_tensor(
+        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
+        raw_dtype: GGMLQuantizationType | None = None,
+    ) -> None:
+        # shards are formed as tensors are added, so decide here whether this tensor starts a new shard
+        # the checks all live in one conditional so that short-circuiting protects the self.shards[-1] accesses when no shard exists yet
+
+        # create the first tensor shard (bool small_first_shard counts a metadata-only shard already present)
+        if (len(self.shards) == self.split_arguments.small_first_shard
+            # or split when over the tensor limit
+            or (self.split_arguments.split_style == SplitStyle.TENSORS
+                and self.shards[-1][1] >= self.split_arguments.split_max_tensors)
+            # or split when over the size limit
+            or (self.split_arguments.split_style == SplitStyle.SIZE
+                and self.shards[-1][2] + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)):
+
+            # we fill in the name later when we know how many shards there are
+            self.shards.append(["", 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])])
+        else:
+            self.shards[-1][1] += 1
+            self.shards[-1][2] += GGUFManager.get_tensor_size(tensor)
+            self.shards[-1][3].append((name, tensor, raw_dtype))
+
+    def close(self) -> None:
+        for writer in self.shard_writers:
+            writer.close()
 
     @staticmethod
     def get_tensor_size(tensor) -> int:
@@ -118,142 +244,11 @@ class SplitStrategy(deque):
 
     @staticmethod
     def format_n_bytes_to_str(num: int) -> str:
+        if num == METADATA_ONLY_INDICATOR:
+            return "negligible - metadata only"
         num = float(num)
         for unit in ("", "K", "M", "G"):
             if abs(num) < 1024.0:
                 return f"{num:3.1f}{unit}"
             num /= 1024.0
-        return f"{num:.1f}T - over 1TB, --split recommended"
-
-
-class GGUFManager(GGUFWriter):
-    kv_data: KVTempData
-    tensors: list[TensorTempData]
-    split_arguments: SplitArguments
-    split_strategy: SplitStrategy
-
-    def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
-                 use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
-                 ) -> None:
-        # we intentionally don't call superclass constructor
-        self.arch = arch
-        self.path = path
-        self.endianess = endianess
-        self.kv_data = {}
-        self.tensors = []
-        self.split_strategy = None
-        self.total_shards = 0
-        self.total_tensors = 0
-        self.use_temp_file = use_temp_file
-        self.split_arguments = split_arguments
-        self.recent_key = None
-        self.state = WriterState.EMPTY
-        self.add_architecture()
-
-    def write_header_to_file(self) -> None:
-        if self.state is not WriterState.EMPTY:
-            raise ValueError(f'Expected GGUFManager state to be EMPTY, got {self.state}')
-
-        self.total_tensors = len(self.tensors)
-        total_size = sum(SplitStrategy.get_tensor_size(tensor[1]) for tensor in self.tensors)
-
-        if self.split_arguments.split_max_tensors and self.total_tensors < self.split_arguments.split_max_tensors:
-            print("Model has fewer tensors than the split threshold, not splitting")
-            self.split_style = SplitStyle.NONE
-
-        if self.split_arguments.split_max_size and total_size < self.split_arguments.split_max_size:
-            print("Model has smaller size than the split threshold, not splitting")
-            self.split_style = SplitStyle.NONE
-
-        self.split_strategy = SplitStrategy(self.path, self.tensors, self.arch, self.split_arguments,
-                                            use_temp_file=self.use_temp_file, endianess=self.endianess)
-        del self.tensors
-        self.total_shards = len(self.split_strategy)
-
-        print("\nWriting the following files:")
-        for (shard_path, shard_tensors, _) in self.split_strategy:
-            size = SplitStrategy.format_n_bytes_to_str(sum(SplitStrategy.get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only"
-            print(f"  {shard_path}: n_tensors = {len(shard_tensors) if shard_tensors else 0}, total_size = {size}")
-
-        if self.split_arguments.dry_run:
-            print("\nDry run, not writing files")
-            # instantiating GGUFWriters creates files
-            for name, _, _ in self.split_strategy:
-                os.remove(name)
-            return
-
-        self.state = WriterState.HEADER
-
-    def write_kv_data_to_file(self) -> None:
-        if self.split_arguments.dry_run:
-            return
-
-        if self.state is not WriterState.HEADER:
-            raise ValueError(f'Expected GGUFManager state to be HEADER, got {self.state}')
-
-        # only the first shard needs all the KV data
-        for key, (value, etype) in self.kv_data.items():
-            self.split_strategy[0][2].add_key(key)
-            self.split_strategy[0][2].add_val(value, etype)
-
-        # the other shards need shard data
-        if self.split_arguments.split_style != SplitStyle.NONE:
-            for i, (_, _, writer) in enumerate(self.split_strategy):
-                writer.add_uint16(LLM_KV_SPLIT_NO, i)
-                writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards)
-                writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors)
-
-        self.state = WriterState.KV_DATA
-
-    def write_tensors_to_file(self, progress: bool = False) -> None:
-        if self.split_arguments.dry_run:
-            return
-
-        if self.state is not WriterState.KV_DATA:
-            raise ValueError(f'Expected GGUFManager state to be KV_DATA, got {self.state}')
-
-        running_total = self.total_tensors
-        for ct in range(self.total_shards):
-            (_, tensors, writer) = self.split_strategy.popleft()
-            tensors = deque(tensors) if tensors else None
-
-            shard_num_tensors = len(tensors) if tensors else 0
-            print(f"Writing to shard {ct}/{self.total_shards} with {shard_num_tensors}/{running_total} remaining tensors (of {self.total_tensors} total)")
-            running_total -= shard_num_tensors
-
-            for _ in range(shard_num_tensors):
-                (name, tensor, dtype) = tensors.popleft()
-                writer.add_tensor(name, tensor, raw_dtype=dtype)
-
-            # need to write everything down here
-            writer.write_header_to_file()
-            writer.write_kv_data_to_file()
-            writer.write_tensors_to_file(progress=progress)
-            del tensors
-
-        self.state = WriterState.TI_DATA
-
-    # override add_key, add_val to handle kv data separately
-    def add_key(self, key: str) -> None:
-        self.recent_key = key
-
-    def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
-        if self.recent_key is None:
-            raise ValueError("No key set for value")
-        self.kv_data[self.recent_key] = (val, vtype)
-
-    # need to handle arrays separately
-    def add_array(self, key: str, val: Sequence[Any]) -> None:
-        if not isinstance(val, Sequence):
-            raise ValueError(f'Expected a sequence for {key}, got {type(val)}')
-        self.kv_data[key] = (val, GGUFValueType.ARRAY)
-
-    def add_tensor(
-        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
-        raw_dtype: GGMLQuantizationType | None = None,
-    ) -> None:
-        self.tensors.append((name, tensor, raw_dtype))
-
-    def close(self) -> None:
-        for _, _, writer in self.split_strategy:
-            writer.close()
\ No newline at end of file
+        return f"{num:.1f}T - over 1TB, --split recommended"
\ No newline at end of file
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 7ef321b91..294f4d06d 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -57,7 +57,7 @@ class GGUFWriter:
 
     def __init__(
         self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True,
-        endianess: GGUFEndian = GGUFEndian.LITTLE,
+        endianess: GGUFEndian = GGUFEndian.LITTLE, dont_add_architecture: bool = False
     ):
         self.fout = open(path, "wb")
         self.arch = arch
@@ -77,7 +77,8 @@ class GGUFWriter:
         ))
         self.state = WriterState.EMPTY
 
-        self.add_architecture()
+        if not dont_add_architecture:
+            self.add_architecture()
 
     def write_header_to_file(self) -> None:
         if self.state is not WriterState.EMPTY:
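
Editor's note: a brief usage sketch of the write path this patch introduces, not
part of the patch itself. Shard formation now happens inside add_tensor(), and
file creation is deferred to init_shards(), which convert-hf-to-gguf.py calls
between write_tensors() and write_header_to_file(). The sketch assumes
GGUFManager and SplitArguments are re-exported from the gguf package (as the
convert script's gguf.SplitArguments call above implies); the output path,
architecture, tensor name, and shape are hypothetical.

    from argparse import Namespace
    from pathlib import Path

    import numpy as np
    import gguf

    # flags mirror the command-line arguments this patch defines
    args = Namespace(split=True, split_max_tensors=2, split_max_size=None,
                     dry_run=False, small_first_shard=False)
    split_arguments = gguf.SplitArguments(args)

    writer = gguf.GGUFManager(Path("model.gguf"), "llama", split_arguments)
    writer.add_key("general.name")
    writer.add_val("example", gguf.GGUFValueType.STRING)
    # tensors are only buffered into shard deques here; no files are created yet
    writer.add_tensor("blk.0.attn_q.weight", np.zeros((32, 32), dtype=np.float32))

    writer.init_shards()  # fix the shard layout and create one GGUFWriter per shard
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file(progress=True)
    writer.close()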