resolve merge + SplitArguments for easier parsing

Christian Zhou-Zheng 2024-05-09 21:22:55 -04:00
parent 14b32915b9
commit 87a98a5b6d
3 changed files with 73 additions and 49 deletions

@@ -58,7 +58,7 @@ class Model:
part_names: list[str]
is_safetensors: bool
hparams: dict[str, Any]
gguf_writer: gguf.GGUFWriter
gguf_writer: gguf.GGUFManager
block_count: int
tensor_map: gguf.TensorNameMap
tensor_names: set[str] | None
@@ -66,7 +66,8 @@ class Model:
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
split_arguments: gguf.SplitArguments):
if self.__class__ == Model:
raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
self.dir_model = dir_model
@@ -83,7 +84,8 @@ class Model:
self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
self.hparams = Model.load_hparams(self.dir_model)
self.gguf_writer = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
self.gguf_writer = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], split_arguments,
endianess=self.endianess, use_temp_file=self.use_temp_file)
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self.tensor_names = None
@@ -275,9 +277,7 @@ class Model:
def write(self):
self.write_tensors()
self.gguf_writer.write_header_to_file()
self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.write_tensors_to_file(progress=True)
self.gguf_writer.write_to_file()
self.gguf_writer.close()
def write_vocab(self):
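Note: with GGUFManager, write() drives just two calls instead of the three GGUFWriter steps. A minimal caller-side sketch, assuming the manager mirrors GGUFWriter's add_* metadata interface (the forwarding helpers at the end of this commit suggest it does):

from pathlib import Path
import numpy as np
import gguf

split_arguments = gguf.SplitArguments()  # defaults: no splitting
writer = gguf.GGUFManager(Path("model.gguf"), gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA], split_arguments)
writer.add_block_count(32)                                                 # KV data is buffered, not written yet
writer.add_tensor("output.weight", np.zeros((16, 16), dtype=np.float32))  # tensors are buffered too
writer.write_to_file()   # plans the shards, then writes header, KV data and tensor data per shard
writer.close()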
@@ -2501,8 +2501,7 @@ def main() -> None:
if args.split_max_tensors and args.split_max_size:
raise ValueError("Can't specify both --split-max-tensors and --split-max-size")
if args.split_max_size:
args.split_max_size = gguf.SplitStrategy.split_str_to_n_bytes(args.split_max_size)
split_arguments = gguf.SplitArguments(args) if args.split else gguf.SplitArguments()
ftype_map = {
"f32": gguf.GGMLQuantizationType.F32,
@@ -2521,7 +2520,8 @@ def main() -> None:
with torch.inference_mode():
model_class = Model.from_model_architecture(hparams["architectures"][0])
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy)
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file,
args.no_lazy, split_arguments)
logger.info("Set model parameters")
model_instance.set_gguf_parameters()

@@ -1065,8 +1065,8 @@ def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False)
class OutputFile:
def __init__(self, fname_out: Path, args: argparse.Namespace, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
self.gguf = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], args, endianess=endianess)
def __init__(self, fname_out: Path, split_arguments: gguf.SplitArguments, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
self.gguf = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], split_arguments, endianess=endianess)
def add_meta_arch(self, params: Params) -> None:
name = "LLaMA"
@@ -1183,7 +1183,7 @@ class OutputFile:
) -> None:
check_vocab_size(params, vocab, pad_vocab=pad_vocab)
of = OutputFile(fname_out, endianess=endianess)
of = OutputFile(fname_out, gguf.SplitArguments(), endianess=endianess)
# meta data
of.add_meta_arch(params)
@@ -1210,10 +1210,10 @@ class OutputFile:
@staticmethod
def write_all(
fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
args: argparse.Namespace, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE
split_arguments: gguf.SplitArguments, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE
) -> None:
check_vocab_size(params, vocab, pad_vocab=args.pad_vocab)
of = OutputFile(fname_out, args, endianess=endianess)
of = OutputFile(fname_out, split_arguments, endianess=endianess)
# meta data
of.add_meta_arch(params)
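Passing a SplitArguments object instead of the converter's whole argparse.Namespace keeps gguf-py independent of each script's CLI. A sketch of programmatic use relying only on attributes this commit defines (whether SplitStyle is re-exported from the gguf package is an assumption):

import gguf

split_arguments = gguf.SplitArguments()   # defaults: SplitStyle.NONE, no dry run
split_arguments.split = True
split_arguments.split_max_tensors = 256
split_arguments.split_style = gguf.SplitStyle.TENSORS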
@@ -1500,8 +1500,7 @@ def main(args_in: list[str] | None = None) -> None:
if args.split_max_tensors and args.split_max_size:
raise ValueError("Can't specify both --split-max-tensors and --split-max-size")
if args.split_max_size:
args.split_max_size = gguf.SplitStrategy.split_str_to_n_bytes(args.split_max_size)
split_arguments = gguf.SplitArguments(args) if args.split else gguf.SplitArguments()
if not args.vocab_only:
model_plus = load_some_model(args.model)
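SplitStrategy.split_str_to_n_bytes itself is not part of this diff; a plausible sketch of the parser that name implies (an assumption, not the actual implementation):

def split_str_to_n_bytes(split_str: str) -> int:
    # e.g. "250K", "64M", "2G", or a plain byte count
    if split_str.endswith("K"):
        return int(split_str[:-1]) * 1024
    if split_str.endswith("M"):
        return int(split_str[:-1]) * 1024 * 1024
    if split_str.endswith("G"):
        return int(split_str[:-1]) * 1024 * 1024 * 1024
    if split_str.isnumeric():
        return int(split_str)
    raise ValueError(f"Invalid split size string: {split_str}")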
@@ -1578,15 +1577,9 @@ def main(args_in: list[str] | None = None) -> None:
params.ftype = ftype
print(f"Writing {outfile}, format {ftype}")
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, args,
concurrency=args.concurrency, endianess=endianess)
if not args.dry_run:
print(f"Wrote {outfile}")
logger.info(f"Writing {outfile}, format {ftype}")
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, split_arguments,
concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
if not args.dry_run:
logger.info(f"Wrote {outfile}")

@@ -45,30 +45,57 @@ class SplitStyle(IntEnum):
SIZE = 2
class SplitArguments:
split: bool
dry_run: bool
small_first_shard: bool
split_max_tensors: int
split_max_size: int
split_style: SplitStyle
def __init__(self) -> None:
self.split = False
self.dry_run = False
self.small_first_shard = False
self.split_max_tensors = 0
self.split_max_size = 0
self.split_style = SplitStyle.NONE
def __init__(self, args: Namespace) -> None:
self.split = args.split
self.split_max_tensors = args.split_max_tensors
self.split_max_size = SplitStrategy.split_str_to_n_bytes(args.split_max_size) if args.split_max_size else None
self.dry_run = args.dry_run
self.small_first_shard = not args.large_first_shard
self.split_style = SplitStyle.NONE if not self.split \
else SplitStyle.TENSORS if self.split_max_tensors \
else SplitStyle.SIZE
class SplitStrategy:
data: SplitTensorsPerFile
def __init__(self, split_style: SplitStyle, fname_out: os.PathLike[str], model: list[TensorTempData],
args: Namespace, arch: str, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE, small_first_shard: bool = True
def __init__(self, fname_out: os.PathLike[str], model: list[TensorTempData], arch: str,
split_arguments: SplitArguments, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE,
):
self.data = []
if split_style == SplitStyle.NONE:
if split_arguments.split_style == SplitStyle.NONE:
self.append((fname_out, model, GGUFWriter(fname_out, arch, use_temp_file=use_temp_file, endianess=endianess)))
elif split_style == SplitStyle.TENSORS:
total_shards = ceil(len(model) / args.split_max_tensors) + small_first_shard
elif split_arguments.split_style == SplitStyle.TENSORS:
total_shards = ceil(len(model) / split_arguments.split_max_tensors) + split_arguments.small_first_shard
shard_files = [fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + 1, total_shards)) for i in range(total_shards)]
if small_first_shard:
if split_arguments.small_first_shard:
self.append((shard_files[0], None, GGUFWriter(shard_files[0], arch, use_temp_file=use_temp_file, endianess=endianess)))
for i, shard in enumerate(shard_files[small_first_shard:]):
start = i * args.split_max_tensors
stop = min((i + 1) * args.split_max_tensors, len(model))
for i, shard in enumerate(shard_files[split_arguments.small_first_shard:]):
start = i * split_arguments.split_max_tensors
stop = min((i + 1) * split_arguments.split_max_tensors, len(model))
self.append((shard, model[start:stop], GGUFWriter(shard, arch, use_temp_file=use_temp_file, endianess=endianess)))
elif split_style == SplitStyle.SIZE:
elif split_arguments.split_style == SplitStyle.SIZE:
shards = []
# we have to assign tensors to shards first in order to know the total shard count - two passes
@@ -76,15 +103,15 @@ class SplitStrategy:
if i == 0:
shards.append([shard])
continue
if SplitStrategy.get_tensor_size(shard[1]) + sum(SplitStrategy.get_tensor_size(t[1]) for t in shards[-1]) > args.split_max_size:
if SplitStrategy.get_tensor_size(shard[1]) + sum(SplitStrategy.get_tensor_size(t[1]) for t in shards[-1]) > split_arguments.split_max_size:
shards.append([shard])
else:
shards[-1].append(shard)
total_shards = len(shards) + small_first_shard
total_shards = len(shards) + split_arguments.small_first_shard
shard_offset = 1
if small_first_shard:
if split_arguments.small_first_shard:
outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, shard_offset, total_shards))
self.append((outname, None, GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
shard_offset += 1
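A small worked example of the shard bookkeeping above; SHARD_NAME_FORMAT is not shown in this commit, so the five-digit "-of-" naming below is an assumption:

from math import ceil

# --split --split-max-tensors 128 on a model with 291 tensors, small_first_shard enabled
total_shards = ceil(291 / 128) + 1   # 3 tensor shards + 1 metadata-only first shard = 4
# shard 1 holds only KV data; shards 2-4 hold tensors [0:128], [128:256], [256:291]
# with an assumed SHARD_NAME_FORMAT of "{}-{:05d}-of-{:05d}.gguf" the files would be
# model-00001-of-00004.gguf ... model-00004-of-00004.gguf
# SplitStyle.SIZE instead grows the current shard until the next tensor would push it
# past split_max_size, then starts a new shard (the two passes above)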
@@ -150,25 +177,23 @@ class SplitStrategy:
class GGUFManager:
kv_data: KVTempData
tensors: list[TensorTempData]
split_style: SplitStyle
split_arguments: SplitArguments
split_strategy: SplitStrategy
def __init__(self, path: os.PathLike[str] | str, arch: str, args: Namespace, use_temp_file: bool = True,
endianess: GGUFEndian = GGUFEndian.LITTLE) -> None:
def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
) -> None:
self.arch = arch
self.path = path
self.endianess = endianess
self.offset_tensor = 0
self.kv_data = {}
self.tensors = []
self.args = args
self.split_style = SplitStyle.NONE if not args.split \
else SplitStyle.TENSORS if args.split_max_tensors \
else SplitStyle.SIZE
self.split_strategy = None
self.total_shards = None
self.total_tensors = None
self.use_temp_file = use_temp_file
self.split_arguments = split_arguments
self.add_architecture()
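The next hunk sums SplitStrategy.get_tensor_size over the buffered tensors; that helper is not part of this diff, so the following is only a plausible sketch (the real one also has to cope with convert.py's lazy tensors):

import numpy as np

def get_tensor_size(tensor) -> int:
    # numpy arrays know their byte size; otherwise fall back to elements * itemsize
    try:
        return tensor.nbytes
    except AttributeError:
        return int(np.prod(tensor.shape)) * tensor.dtype.itemsize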
@@ -183,15 +208,16 @@ class GGUFManager:
self.total_tensors = len(self.tensors)
total_size = sum(SplitStrategy.get_tensor_size(tensor[1]) for tensor in self.tensors)
if self.args.split_max_tensors and self.total_tensors < self.args.split_max_tensors:
if self.split_arguments.split_max_tensors and self.total_tensors < self.split_arguments.split_max_tensors:
print("Model has fewer tensors than the split threshold, not splitting")
self.split_style = SplitStyle.NONE
if self.args.split_max_size and total_size < self.args.split_max_size:
if self.split_arguments.split_max_size and total_size < self.split_arguments.split_max_size:
print("Model has smaller size than the split threshold, not splitting")
self.split_style = SplitStyle.NONE
self.split_strategy = SplitStrategy(self.split_style, self.path, self.tensors, self.args, self.arch)
self.split_strategy = SplitStrategy(self.path, self.tensors, self.arch, self.split_arguments,
use_temp_file=self.use_temp_file, endianess=self.endianess)
self.total_shards = len(self.split_strategy)
# only the first shard needs all the KV data
@@ -199,7 +225,7 @@ class GGUFManager:
self.split_strategy[0][2].add_key(key)
self.split_strategy[0][2].add_val(value, etype)
if self.split_style != SplitStyle.NONE:
if self.split_arguments.split_style != SplitStyle.NONE:
for i, (_, _, writer) in enumerate(self.split_strategy):
writer.add_uint16(LLM_KV_SPLIT_NO, i)
writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards)
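The split keys added above follow the gguf-split convention; the exact key strings behind the constants are not part of this diff, so the values below are assumptions:

LLM_KV_SPLIT_NO = "split.no"        # index of this shard
LLM_KV_SPLIT_COUNT = "split.count"  # total number of shards
# shard 0 additionally carries the full model KV data, replayed via add_key()/add_val() above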
@@ -219,7 +245,7 @@ class GGUFManager:
size = SplitStrategy.format_n_bytes_to_str(sum(SplitStrategy.get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only"
print(f" {shard_path}: n_tensors = {len(shard_tensors) if shard_tensors else 0}, total_size = {size}")
if self.args.dry_run:
if self.split_arguments.dry_run:
print("\nDry run, not writing files")
# instantiating GGUFWriters creates files
for name, _, _ in self.split_strategy:
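The dry-run branch is cut off by the hunk boundary; presumably it deletes the empty files that constructing each GGUFWriter already opened and stops before any tensor data is written. A sketch of that assumption (os is already imported by this module):

if self.split_arguments.dry_run:
    print("\nDry run, not writing files")
    # instantiating GGUFWriters creates files
    for name, _, _ in self.split_strategy:
        os.remove(name)   # assumption: discard the placeholder shard files
    return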
@@ -232,6 +258,7 @@ class GGUFManager:
for i, (_, tensors, writer) in enumerate(self.split_strategy):
if tensors:
print(f"\nWriting to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
for j, (name, tensor) in enumerate(tensors):
n_elements = int(np.prod(tensor.shape))
# logic from convert.py
@@ -251,6 +278,7 @@ class GGUFManager:
f"[{j + 1:{padi}d}/{len(tensors)}] Writing tensor {name:38s} | size {size:16} | type {dtype:8} | T+{int(elapsed):4}"
)
writer.add_tensor(name, tensor)
print(f"Writing to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
writer.write_header_to_file()
@@ -258,7 +286,7 @@ class GGUFManager:
writer.write_tensors_to_file()
if tensors:
print(f"\nWriting to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
# TODO this shows up AFTER writing which we don't really want - move it
running_total -= len(tensors)
if write_tensor_data:
@@ -473,6 +501,9 @@ class GGUFManager:
def add_tokenizer_model(self, model: str) -> None:
self.add_string(Keys.Tokenizer.MODEL, model)
def add_tokenizer_pre(self, pre: str) -> None:
self.add_string(Keys.Tokenizer.PRE, pre)
def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
self.add_array(Keys.Tokenizer.LIST, tokens)
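The add_* helpers at the tail of the file (add_tokenizer_model, add_tokenizer_pre, add_token_list, ...) do not write anything themselves; judging from the add_key()/add_val() replay earlier in this commit, they buffer values into kv_data keyed by value type. A minimal sketch of that assumed pattern:

from typing import Any, Sequence
from gguf.constants import GGUFValueType

def add_string(self, key: str, val: str) -> None:
    # buffered; replayed into the first shard by write_to_file()
    self.kv_data[key] = (val, GGUFValueType.STRING)

def add_array(self, key: str, val: Sequence[Any]) -> None:
    self.kv_data[key] = (val, GGUFValueType.ARRAY)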