form shards while adding tensors, SHA256 sums agree with master

Christian Zhou-Zheng 2024-06-05 18:29:39 -04:00
parent 5ad397d610
commit ce7e6985d2
3 changed files with 176 additions and 178 deletions
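The core of this change is that shard boundaries are now decided incrementally, while tensors are being added, instead of being planned in a separate pass over a complete tensor list. The writer keeps a running tensor count and byte total for the current shard and starts a new shard whenever adding the next tensor would cross the --split-max-tensors or --split-max-size limit. The snippet below is a minimal standalone sketch of that decision rule for illustration only; the names (ShardPlan, add_tensor_to_plan) and the limits are invented here and are not part of the commit, where the real logic lives in GGUFManager.add_tensor.

# Standalone sketch of "form shards while adding tensors" (illustration only;
# these names are hypothetical, the commit's actual logic is GGUFManager.add_tensor below).
from collections import deque
from dataclasses import dataclass, field

import numpy as np


@dataclass
class ShardPlan:
    n_tensors: int = 0                              # tensors assigned to this shard so far
    n_bytes: int = 0                                # running payload size of this shard
    tensors: deque = field(default_factory=deque)   # (name, array) pairs


def add_tensor_to_plan(shards: list[ShardPlan], name: str, tensor: np.ndarray,
                       max_tensors: int = 0, max_bytes: int = 0) -> None:
    # start a new shard when there is none yet or a limit would be exceeded
    size = tensor.nbytes
    if (not shards
            or (max_tensors and shards[-1].n_tensors >= max_tensors)
            or (max_bytes and shards[-1].n_bytes + size > max_bytes)):
        shards.append(ShardPlan())
    shards[-1].n_tensors += 1
    shards[-1].n_bytes += size
    shards[-1].tensors.append((name, tensor))


plan: list[ShardPlan] = []
for i in range(10):
    add_tensor_to_plan(plan, f"blk.{i}.weight", np.zeros((64, 64), dtype=np.float32), max_tensors=4)
print([s.n_tensors for s in plan])  # -> [4, 4, 2]

Because each shard's tensor count and size are known once all tensors have been added, shard file names and the split.* metadata can be filled in afterwards in a single init_shards() step, which is what the diff below introduces.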


@@ -327,6 +327,7 @@ class Model:
     def write(self):
         self.write_tensors()
+        self.gguf_writer.init_shards()
         self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.write_tensors_to_file(progress=True)
@@ -335,6 +336,7 @@ class Model:
     def write_vocab(self):
         if self.gguf_writer.split_arguments.split:
             raise ValueError('Splitting the vocabulary is not supported')
+        self.gguf_writer.init_shards()
         self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()
@@ -2816,8 +2818,8 @@ def parse_args() -> argparse.Namespace:
         help="only print out a split plan and exit, without writing any new files"
     )
     parser.add_argument(
-        "--large-first-shard", action="store_true",
-        help="include tensors in the first shard when splitting (default: metadata only)"
+        "--small-first-shard", action="store_true",
+        help="do not add tensors to the first shard (disabled by default)"
    )
     return parser.parse_args()
@@ -2853,7 +2855,7 @@ def main() -> None:
     if args.split_max_tensors and args.split_max_size:
         raise ValueError("Can't specify both --split-max-tensors and --split-max-size")
-    split_arguments = gguf.SplitArguments(args=args) if args.split else gguf.SplitArguments()
+    split_arguments = gguf.SplitArguments(args)
     ftype_map = {
         "f32": gguf.LlamaFileType.ALL_F32,


@@ -4,7 +4,6 @@ import os
 from enum import IntEnum
 from typing import TYPE_CHECKING, Any, Sequence
 from argparse import Namespace
-from math import ceil
 from collections import deque
 
 import numpy as np
@@ -21,14 +20,15 @@ from .gguf_writer import GGUFWriter, WriterState
 SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
+METADATA_ONLY_INDICATOR = -1
 
 LLM_KV_SPLIT_NO = "split.no"
 LLM_KV_SPLIT_COUNT = "split.count"
 LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"
 
-SplitTensorsPerFile: TypeAlias = deque[tuple[os.PathLike[str], deque[tuple[str, Any]], GGUFWriter]]  # [(outfile name, [(tensor name, tensor data)] for each tensor in file, filewriter)]
 KVTempData: TypeAlias = dict[str, tuple[Any, GGUFValueType]]  # {key: (value, type)}
 TensorTempData: TypeAlias = tuple[str, np.ndarray[Any, Any], GGMLQuantizationType]  # (tensor name, tensor data, tensor dtype)
+Shard: TypeAlias = list[os.PathLike[str], int, int, deque[TensorTempData]]  # [shard filename, shard tensor count, shard size, [tensor data]]
 
 
 class SplitStyle(IntEnum):
@@ -38,56 +38,182 @@ class SplitStyle(IntEnum):
 class SplitArguments:
-    def __init__(self, args: Namespace = None) -> None:
-        self.split = args.split if args else False
-        self.split_max_tensors = args.split_max_tensors if args else 0
-        self.split_max_size = SplitStrategy.split_str_to_n_bytes(args.split_max_size) if args and args.split_max_size else 0
-        self.dry_run = args.dry_run if args else False
-        self.small_first_shard = not args.large_first_shard if args else False
-        self.split_style = SplitStyle.NONE if not self.split or not args \
+    def __init__(self, args: Namespace) -> None:
+        self.split = args.split
+        self.split_max_tensors = args.split_max_tensors if args.split else 0
+        self.split_max_size = GGUFManager.split_str_to_n_bytes(args.split_max_size) if args.split and args.split_max_size else 0
+        self.split_style = SplitStyle.NONE if not self.split \
             else SplitStyle.TENSORS if self.split_max_tensors \
             else SplitStyle.SIZE
+        self.dry_run = args.dry_run
+        self.small_first_shard = args.small_first_shard
 
 
-class SplitStrategy(deque):
-    data: SplitTensorsPerFile
-
-    def __init__(self, fname_out: os.PathLike[str], model: list[TensorTempData], arch: str,
-                 split_arguments: SplitArguments, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE,
-    ):
-        super().__init__()
-
-        if split_arguments.split_style == SplitStyle.NONE:
-            self.append((fname_out, model, GGUFWriter(fname_out, arch, use_temp_file=use_temp_file, endianess=endianess)))
-
-        elif split_arguments.split_style == SplitStyle.TENSORS:
-            total_shards = ceil(len(model) / split_arguments.split_max_tensors) + split_arguments.small_first_shard
-            shard_files = [fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + 1, total_shards)) for i in range(total_shards)]
-
-            if split_arguments.small_first_shard:
-                self.append((shard_files[0], None, GGUFWriter(shard_files[0], arch, use_temp_file=use_temp_file, endianess=endianess)))
-
-            for i, shard in enumerate(shard_files[split_arguments.small_first_shard:]):
-                start = i * split_arguments.split_max_tensors
-                stop = min((i + 1) * split_arguments.split_max_tensors, len(model))
-                self.append((shard, model[start:stop], GGUFWriter(shard, arch, use_temp_file=use_temp_file, endianess=endianess)))
-
-        elif split_arguments.split_style == SplitStyle.SIZE:
-            shards = [[model[0]]]
-
-            # we have to determine the shards first to determine how many shards there will be in total - two passes
-            for i, shard in enumerate(model[1:]):
-                if SplitStrategy.get_tensor_size(shard[1]) + sum(SplitStrategy.get_tensor_size(t[1]) for t in shards[-1]) > split_arguments.split_max_size:
-                    shards.append([shard])
-                else:
-                    shards[-1].append(shard)
-
-            if split_arguments.small_first_shard:
-                shards.insert(0, None)
-
-            for i, shard in enumerate(shards):
-                outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + 1, len(shards)))
-                self.append((outname, shard, GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
+class GGUFManager(GGUFWriter):
+    kv_data: KVTempData
+    tensors: list[TensorTempData]
+    split_arguments: SplitArguments
+    shards: list[Shard]
+    shard_writers: list[GGUFWriter]
+
+    def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
+                 use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
+    ) -> None:
+        # we intentionally don't call superclass constructor
+        self.arch = arch
+        self.path = path
+        self.endianess = endianess
+        self.kv_data = {}
+        self.shards = []
+        self.shard_writers = []
+        self.total_tensors = 0
+        self.use_temp_file = use_temp_file
+        self.split_arguments = split_arguments
+        self.recent_key = None
+        self.state = WriterState.EMPTY
+
+        if self.split_arguments.small_first_shard:
+            self.shards.append(["", 0, METADATA_ONLY_INDICATOR, None])
+
+    def init_shards(self) -> None:
+        self.total_tensors = sum(shard[1] for shard in self.shards)
+        total_size = sum(shard[2] for shard in self.shards)
+
+        # check if we need to split
+        if self.split_arguments.split_max_tensors and self.total_tensors < self.split_arguments.split_max_tensors:
+            print("Model has fewer tensors than the split threshold, not splitting")
+            self.split_style = SplitStyle.NONE
+
+        if self.split_arguments.split_max_size and total_size < self.split_arguments.split_max_size:
+            print("Model has smaller size than the split threshold, not splitting")
+            self.split_style = SplitStyle.NONE
+
+        # no shards are created when writing vocab so make one
+        if not self.shards:
+            self.shards.append(["", 0, METADATA_ONLY_INDICATOR, None])
+
+        # format shard names
+        if len(self.shards) == 1:
+            self.shards[0][0] = self.path
+        else:
+            for i in range(len(self.shards)):
+                self.shards[i][0] = self.path.with_name(SHARD_NAME_FORMAT.format(self.path.stem, i + 1, len(self.shards)))
+
+        # print shard info
+        print("\nWriting the following files:")
+        for (path, tensor_ct, size, _) in self.shards:
+            print(f"  {path}: n_tensors = {tensor_ct}, total_size = {GGUFManager.format_n_bytes_to_str(size)}")
+        print()
+
+        if self.split_arguments.dry_run:
+            print("\nDry run, not writing files")
+            exit()
+
+        # we don't want to initialize GGUFWriters until now because they create files
+        for i, (path, _, _, tensors) in enumerate(self.shards):
+            # dont_add_architecture is used for consistency - examples/gguf_split doesn't add arch to all shards
+            writer = GGUFWriter(path, self.arch, use_temp_file=self.use_temp_file,
+                                endianess=self.endianess, dont_add_architecture=not (i == 0))
+
+            # only the first shard needs all the KV data
+            if i == 0:
+                for key, (value, etype) in self.kv_data.items():
+                    writer.add_key(key)
+                    writer.add_val(value, etype)
+
+            # add split metadata unless it's one file - small first shard splits even with SplitStyle.NONE
+            if self.split_arguments.split_style != SplitStyle.NONE or self.split_arguments.small_first_shard:
+                writer.add_uint16(LLM_KV_SPLIT_NO, i)
+                writer.add_uint16(LLM_KV_SPLIT_COUNT, len(self.shards))
+                writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors)
+
+            # add tensors, deque popleft() ensures references to eager tensors are not kept
+            while True:
+                try:
+                    (name, tensor, dtype) = tensors.popleft()
+                    writer.add_tensor(name, tensor, raw_dtype=dtype)
+                except:
+                    break
+
+            self.shard_writers.append(writer)
+
+    def write_header_to_file(self) -> None:
+        if self.state is not WriterState.EMPTY:
+            raise ValueError(f'Expected GGUFManager state to be EMPTY, got {self.state}')
+
+        for writer in self.shard_writers:
+            writer.write_header_to_file()
+
+        self.state = WriterState.HEADER
+
+    def write_kv_data_to_file(self) -> None:
+        if self.state is not WriterState.HEADER:
+            raise ValueError(f'Expected GGUFManager state to be HEADER, got {self.state}')
+
+        for writer in self.shard_writers:
+            writer.write_kv_data_to_file()
+
+        self.state = WriterState.KV_DATA
+
+    def write_tensors_to_file(self, progress: bool = False) -> None:
+        if self.state is not WriterState.KV_DATA:
+            raise ValueError(f'Expected GGUFManager state to be KV_DATA, got {self.state}')
+
+        running_total = self.total_tensors
+        for i in range(len(self.shard_writers)):
+            writer = self.shard_writers[i]
+            is_metadata = writer.ti_data_count == 0
+            if is_metadata:
+                print(f"Writing to shard {i + 1}/{len(self.shards)} with metadata only")
+            else:
+                print(f"Writing to shard {i + 1}/{len(self.shards)} with {writer.ti_data_count}/{running_total} remaining tensors (of {self.total_tensors} total)")
+            running_total -= writer.ti_data_count
+            writer.write_tensors_to_file(progress=(progress and not is_metadata))
+            del writer
+
+        self.state = WriterState.TI_DATA
+
+    # override add_key, add_val to handle kv data separately
+    def add_key(self, key: str) -> None:
+        self.recent_key = key
+
+    def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
+        if self.recent_key is None:
+            raise ValueError("No key set for value")
+        self.kv_data[self.recent_key] = (val, vtype)
+
+    # need to handle arrays separately
+    def add_array(self, key: str, val: Sequence[Any]) -> None:
+        if not isinstance(val, Sequence):
+            raise ValueError(f'Expected a sequence for {key}, got {type(val)}')
+        self.kv_data[key] = (val, GGUFValueType.ARRAY)
+
+    def add_tensor(
+        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
+        raw_dtype: GGMLQuantizationType | None = None,
+    ) -> None:
+        # we build splits as tensors are added so we need logic to figure out when to split
+        # logic is all in the conditional because it short-circuits, otherwise accessing self.shards[-1] would throw an error
+        # create a first shard to start it off
+        if (len(self.shards) == self.split_arguments.small_first_shard \
+            # or split when over tensor limit
+            or (self.split_arguments.split_style == SplitStyle.TENSORS \
+                and self.shards[-1][1] >= self.split_arguments.split_max_tensors) \
+            # or split when over size limit
+            or (self.split_arguments.split_style == SplitStyle.SIZE \
+                and self.shards[-1][2] + GGUFManager.get_tensor_size(tensor) > self.split_arguments.split_max_size)):
+            # we fill in the name later when we know how many shards there are
+            self.shards.append(["", 1, GGUFManager.get_tensor_size(tensor), deque([(name, tensor, raw_dtype)])])
+        else:
+            self.shards[-1][1] += 1
+            self.shards[-1][2] += GGUFManager.get_tensor_size(tensor)
+            self.shards[-1][3].append((name, tensor, raw_dtype))
+
+    def close(self) -> None:
+        for writer in self.shard_writers:
+            writer.close()
 
     @staticmethod
     def get_tensor_size(tensor) -> int:
@@ -118,142 +244,11 @@ class SplitStrategy(deque):
     @staticmethod
     def format_n_bytes_to_str(num: int) -> str:
+        if num == METADATA_ONLY_INDICATOR:
+            return "negligible - metadata only"
         num = float(num)
         for unit in ("", "K", "M", "G"):
             if abs(num) < 1024.0:
                 return f"{num:3.1f}{unit}"
             num /= 1024.0
         return f"{num:.1f}T - over 1TB, --split recommended"
-
-
-class GGUFManager(GGUFWriter):
-    kv_data: KVTempData
-    tensors: list[TensorTempData]
-    split_arguments: SplitArguments
-    split_strategy: SplitStrategy
-
-    def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
-                 use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
-    ) -> None:
-        # we intentionally don't call superclass constructor
-        self.arch = arch
-        self.path = path
-        self.endianess = endianess
-        self.kv_data = {}
-        self.tensors = []
-        self.split_strategy = None
-        self.total_shards = 0
-        self.total_tensors = 0
-        self.use_temp_file = use_temp_file
-        self.split_arguments = split_arguments
-        self.recent_key = None
-        self.state = WriterState.EMPTY
-        self.add_architecture()
-
-    def write_header_to_file(self) -> None:
-        if self.state is not WriterState.EMPTY:
-            raise ValueError(f'Expected GGUFManager state to be EMPTY, got {self.state}')
-
-        self.total_tensors = len(self.tensors)
-        total_size = sum(SplitStrategy.get_tensor_size(tensor[1]) for tensor in self.tensors)
-
-        if self.split_arguments.split_max_tensors and self.total_tensors < self.split_arguments.split_max_tensors:
-            print("Model has fewer tensors than the split threshold, not splitting")
-            self.split_style = SplitStyle.NONE
-
-        if self.split_arguments.split_max_size and total_size < self.split_arguments.split_max_size:
-            print("Model has smaller size than the split threshold, not splitting")
-            self.split_style = SplitStyle.NONE
-
-        self.split_strategy = SplitStrategy(self.path, self.tensors, self.arch, self.split_arguments,
-                                            use_temp_file=self.use_temp_file, endianess=self.endianess)
-        del self.tensors
-        self.total_shards = len(self.split_strategy)
-
-        print("\nWriting the following files:")
-        for (shard_path, shard_tensors, _) in self.split_strategy:
-            size = SplitStrategy.format_n_bytes_to_str(sum(SplitStrategy.get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only"
-            print(f"  {shard_path}: n_tensors = {len(shard_tensors) if shard_tensors else 0}, total_size = {size}")
-
-        if self.split_arguments.dry_run:
-            print("\nDry run, not writing files")
-            # instantiating GGUFWriters creates files
-            for name, _, _ in self.split_strategy:
-                os.remove(name)
-            return
-
-        self.state = WriterState.HEADER
-
-    def write_kv_data_to_file(self) -> None:
-        if self.split_arguments.dry_run:
-            return
-        if self.state is not WriterState.HEADER:
-            raise ValueError(f'Expected GGUFManager state to be HEADER, got {self.state}')
-
-        # only the first shard needs all the KV data
-        for key, (value, etype) in self.kv_data.items():
-            self.split_strategy[0][2].add_key(key)
-            self.split_strategy[0][2].add_val(value, etype)
-
-        # the other shards need shard data
-        if self.split_arguments.split_style != SplitStyle.NONE:
-            for i, (_, _, writer) in enumerate(self.split_strategy):
-                writer.add_uint16(LLM_KV_SPLIT_NO, i)
-                writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards)
-                writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors)
-
-        self.state = WriterState.KV_DATA
-
-    def write_tensors_to_file(self, progress: bool = False) -> None:
-        if self.split_arguments.dry_run:
-            return
-        if self.state is not WriterState.KV_DATA:
-            raise ValueError(f'Expected GGUFManager state to be KV_DATA, got {self.state}')
-
-        running_total = self.total_tensors
-        for ct in range(self.total_shards):
-            (_, tensors, writer) = self.split_strategy.popleft()
-            tensors = deque(tensors) if tensors else None
-
-            shard_num_tensors = len(tensors) if tensors else 0
-            print(f"Writing to shard {ct}/{self.total_shards} with {shard_num_tensors}/{running_total} remaining tensors (of {self.total_tensors} total)")
-            running_total -= shard_num_tensors
-
-            for _ in range(shard_num_tensors):
-                (name, tensor, dtype) = tensors.popleft()
-                writer.add_tensor(name, tensor, raw_dtype=dtype)
-
-            # need to write everything down here
-            writer.write_header_to_file()
-            writer.write_kv_data_to_file()
-            writer.write_tensors_to_file(progress=progress)
-            del tensors
-
-        self.state = WriterState.TI_DATA
-
-    # override add_key, add_val to handle kv data separately
-    def add_key(self, key: str) -> None:
-        self.recent_key = key
-
-    def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
-        if self.recent_key is None:
-            raise ValueError("No key set for value")
-        self.kv_data[self.recent_key] = (val, vtype)
-
-    # need to handle arrays separately
-    def add_array(self, key: str, val: Sequence[Any]) -> None:
-        if not isinstance(val, Sequence):
-            raise ValueError(f'Expected a sequence for {key}, got {type(val)}')
-        self.kv_data[key] = (val, GGUFValueType.ARRAY)
-
-    def add_tensor(
-        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
-        raw_dtype: GGMLQuantizationType | None = None,
-    ) -> None:
-        self.tensors.append((name, tensor, raw_dtype))
-
-    def close(self) -> None:
-        for _, _, writer in self.split_strategy:
-            writer.close()


@@ -57,7 +57,7 @@ class GGUFWriter:
     def __init__(
         self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True,
-        endianess: GGUFEndian = GGUFEndian.LITTLE,
+        endianess: GGUFEndian = GGUFEndian.LITTLE, dont_add_architecture: bool = False
     ):
         self.fout = open(path, "wb")
         self.arch = arch
@@ -77,7 +77,8 @@ class GGUFWriter:
         ))
         self.state = WriterState.EMPTY
-        self.add_architecture()
+        if not dont_add_architecture:
+            self.add_architecture()
 
     def write_header_to_file(self) -> None:
         if self.state is not WriterState.EMPTY: