diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 8e8ecfe3d..4ba681473 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -58,7 +58,7 @@ class Model:
     part_names: list[str]
     is_safetensors: bool
    hparams: dict[str, Any]
-    gguf_writer: gguf.GGUFWriter
+    gguf_writer: gguf.GGUFManager
     block_count: int
     tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
@@ -66,7 +66,8 @@ class Model:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

-    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
+    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool,
+                 split_arguments: gguf.SplitArguments):
         if self.__class__ == Model:
             raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
@@ -83,7 +84,8 @@ class Model:
             self.part_names = Model.get_model_part_names(self.dir_model, ".bin")

         self.hparams = Model.load_hparams(self.dir_model)
-        self.gguf_writer = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file)
+        self.gguf_writer = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], split_arguments,
+                                            endianess=self.endianess, use_temp_file=self.use_temp_file)
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
@@ -275,9 +277,7 @@ class Model:

     def write(self):
         self.write_tensors()
-        self.gguf_writer.write_header_to_file()
-        self.gguf_writer.write_kv_data_to_file()
-        self.gguf_writer.write_tensors_to_file(progress=True)
+        self.gguf_writer.write_to_file()
         self.gguf_writer.close()

     def write_vocab(self):
@@ -2501,8 +2501,7 @@ def main() -> None:
     if args.split_max_tensors and args.split_max_size:
         raise ValueError("Can't specify both --split-max-tensors and --split-max-size")

-    if args.split_max_size:
-        args.split_max_size = gguf.SplitStrategy.split_str_to_n_bytes(args.split_max_size)
+    split_arguments = gguf.SplitArguments(args) if args.split else gguf.SplitArguments()

     ftype_map = {
         "f32": gguf.GGMLQuantizationType.F32,
@@ -2521,7 +2520,8 @@ def main() -> None:

     with torch.inference_mode():
         model_class = Model.from_model_architecture(hparams["architectures"][0])
-        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy)
+        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file,
+                                     args.no_lazy, split_arguments)

     logger.info("Set model parameters")
     model_instance.set_gguf_parameters()
diff --git a/convert.py b/convert.py
index cc133576f..9ee0f1ce7 100755
--- a/convert.py
+++ b/convert.py
@@ -1065,8 +1065,8 @@ def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False)


 class OutputFile:
-    def __init__(self, fname_out: Path, args: argparse.Namespace, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
-        self.gguf = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], args, endianess=endianess)
+    def __init__(self, fname_out: Path, split_arguments: gguf.SplitArguments, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
+        self.gguf = gguf.GGUFManager(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], split_arguments, endianess=endianess)

     def add_meta_arch(self, params: Params) -> None:
         name = "LLaMA"
@@ -1183,7 +1183,7 @@ class OutputFile:
     ) -> None:
         check_vocab_size(params, vocab, pad_vocab=pad_vocab)

-        of = OutputFile(fname_out, endianess=endianess)
+        of = OutputFile(fname_out, gguf.SplitArguments(), endianess=endianess)

         # meta data
         of.add_meta_arch(params)
@@ -1210,10 +1210,11 @@ class OutputFile:
     @staticmethod
     def write_all(
         fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab,
-        args: argparse.Namespace, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE
+        split_arguments: gguf.SplitArguments, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE,
+        pad_vocab: bool = False,
     ) -> None:
-        check_vocab_size(params, vocab, pad_vocab=args.pad_vocab)
-        of = OutputFile(fname_out, args, endianess=endianess)
+        check_vocab_size(params, vocab, pad_vocab=pad_vocab)
+        of = OutputFile(fname_out, split_arguments, endianess=endianess)

         # meta data
         of.add_meta_arch(params)
@@ -1500,8 +1500,7 @@ def main(args_in: list[str] | None = None) -> None:
     if args.split_max_tensors and args.split_max_size:
         raise ValueError("Can't specify both --split-max-tensors and --split-max-size")

-    if args.split_max_size:
-        args.split_max_size = gguf.SplitStrategy.split_str_to_n_bytes(args.split_max_size)
+    split_arguments = gguf.SplitArguments(args) if args.split else gguf.SplitArguments()

     if not args.vocab_only:
         model_plus = load_some_model(args.model)
@@ -1578,15 +1577,9 @@ def main(args_in: list[str] | None = None) -> None:

     params.ftype = ftype

-    print(f"Writing {outfile}, format {ftype}")
-    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, args,
-                         concurrency=args.concurrency, endianess=endianess)
-    if not args.dry_run:
-        print(f"Wrote {outfile}")
-
     logger.info(f"Writing {outfile}, format {ftype}")
-    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab,
+    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, split_arguments,
                          concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab)
     if not args.dry_run:
         logger.info(f"Wrote {outfile}")
diff --git a/gguf-py/gguf/gguf_manager.py b/gguf-py/gguf/gguf_manager.py
index 5963aa036..f36b0173e 100644
--- a/gguf-py/gguf/gguf_manager.py
+++ b/gguf-py/gguf/gguf_manager.py
@@ -45,30 +45,50 @@ class SplitStyle(IntEnum):
     SIZE = 2


+class SplitArguments:
+    split: bool
+    dry_run: bool
+    small_first_shard: bool
+    split_max_tensors: int
+    split_max_size: int
+    split_style: SplitStyle
+
+    # args is the parsed CLI namespace; omit it to get no-split defaults
+    def __init__(self, args: Namespace | None = None) -> None:
+        self.split = args.split if args else False
+        self.split_max_tensors = args.split_max_tensors if args else 0
+        self.split_max_size = SplitStrategy.split_str_to_n_bytes(args.split_max_size) if args and args.split_max_size else 0
+        self.dry_run = args.dry_run if args else False
+        self.small_first_shard = not args.large_first_shard if args else False
+        self.split_style = SplitStyle.NONE if not self.split \
+            else SplitStyle.TENSORS if self.split_max_tensors \
+            else SplitStyle.SIZE
+
+
 class SplitStrategy:
     data: SplitTensorsPerFile

-    def __init__(self, split_style: SplitStyle, fname_out: os.PathLike[str], model: list[TensorTempData],
-        args: Namespace, arch: str, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE, small_first_shard: bool = True
+    def __init__(self, fname_out: os.PathLike[str], model: list[TensorTempData], arch: str,
+        split_arguments: SplitArguments, use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE,
     ):
         self.data = []

-        if split_style == SplitStyle.NONE:
+        if split_arguments.split_style == SplitStyle.NONE:
             self.append((fname_out, model, GGUFWriter(fname_out, arch, use_temp_file=use_temp_file, endianess=endianess)))

-        elif split_style == SplitStyle.TENSORS:
-            total_shards = ceil(len(model) / args.split_max_tensors) + small_first_shard
+        elif split_arguments.split_style == SplitStyle.TENSORS:
+            total_shards = ceil(len(model) / split_arguments.split_max_tensors) + split_arguments.small_first_shard
             shard_files = [fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, i + 1, total_shards)) for i in range(total_shards)]

-            if small_first_shard:
+            if split_arguments.small_first_shard:
                 self.append((shard_files[0], None, GGUFWriter(shard_files[0], arch, use_temp_file=use_temp_file, endianess=endianess)))

-            for i, shard in enumerate(shard_files[small_first_shard:]):
-                start = i * args.split_max_tensors
-                stop = min((i + 1) * args.split_max_tensors, len(model))
+            for i, shard in enumerate(shard_files[split_arguments.small_first_shard:]):
+                start = i * split_arguments.split_max_tensors
+                stop = min((i + 1) * split_arguments.split_max_tensors, len(model))
                 self.append((shard, model[start:stop], GGUFWriter(shard, arch, use_temp_file=use_temp_file, endianess=endianess)))

-        elif split_style == SplitStyle.SIZE:
+        elif split_arguments.split_style == SplitStyle.SIZE:
             shards = []

             # we have to determine the shards first to determine how many shards there will be in total - two passes
@@ -76,15 +103,15 @@ class SplitStrategy:
                 if i == 0:
                     shards.append([shard])
                     continue
-                if SplitStrategy.get_tensor_size(shard[1]) + sum(SplitStrategy.get_tensor_size(t[1]) for t in shards[-1]) > args.split_max_size:
+                if SplitStrategy.get_tensor_size(shard[1]) + sum(SplitStrategy.get_tensor_size(t[1]) for t in shards[-1]) > split_arguments.split_max_size:
                     shards.append([shard])
                 else:
                     shards[-1].append(shard)

-            total_shards = len(shards) + small_first_shard
+            total_shards = len(shards) + split_arguments.small_first_shard
             shard_offset = 1

-            if small_first_shard:
+            if split_arguments.small_first_shard:
                 outname = fname_out.with_name(SHARD_NAME_FORMAT.format(fname_out.stem, shard_offset, total_shards))
                 self.append((outname, None, GGUFWriter(outname, arch, use_temp_file=use_temp_file, endianess=endianess)))
                 shard_offset += 1
@@ -150,25 +177,23 @@ class SplitStrategy:

 class GGUFManager:
     kv_data: KVTempData
     tensors: list[TensorTempData]
-    split_style: SplitStyle
+    split_arguments: SplitArguments
     split_strategy: SplitStrategy

-    def __init__(self, path: os.PathLike[str] | str, arch: str, args: Namespace, use_temp_file: bool = True,
-                 endianess: GGUFEndian = GGUFEndian.LITTLE) -> None:
+    def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
+                 use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
+    ) -> None:
         self.arch = arch
         self.path = path
         self.endianess = endianess
         self.offset_tensor = 0
         self.kv_data = {}
         self.tensors = []
-        self.args = args
-        self.split_style = SplitStyle.NONE if not args.split \
-            else SplitStyle.TENSORS if args.split_max_tensors \
-            else SplitStyle.SIZE
         self.split_strategy = None
         self.total_shards = None
         self.total_tensors = None
         self.use_temp_file = use_temp_file
+        self.split_arguments = split_arguments

         self.add_architecture()
@@ -183,15 +208,16 @@ class GGUFManager:
         self.total_tensors = len(self.tensors)
         total_size = sum(SplitStrategy.get_tensor_size(tensor[1]) for tensor in self.tensors)

-        if self.args.split_max_tensors and self.total_tensors < self.args.split_max_tensors:
+        if self.split_arguments.split_max_tensors and self.total_tensors < self.split_arguments.split_max_tensors:
             print("Model has fewer tensors than the split threshold, not splitting")
-            self.split_style = SplitStyle.NONE
+            self.split_arguments.split_style = SplitStyle.NONE

-        if self.args.split_max_size and total_size < self.args.split_max_size:
+        if self.split_arguments.split_max_size and total_size < self.split_arguments.split_max_size:
             print("Model has smaller size than the split threshold, not splitting")
-            self.split_style = SplitStyle.NONE
+            self.split_arguments.split_style = SplitStyle.NONE

-        self.split_strategy = SplitStrategy(self.split_style, self.path, self.tensors, self.args, self.arch)
+        self.split_strategy = SplitStrategy(self.path, self.tensors, self.arch, self.split_arguments,
+                                            use_temp_file=self.use_temp_file, endianess=self.endianess)
         self.total_shards = len(self.split_strategy)

         # only the first shard needs all the KV data
@@ -199,7 +225,7 @@ class GGUFManager:
             self.split_strategy[0][2].add_key(key)
             self.split_strategy[0][2].add_val(value, etype)

-        if self.split_style != SplitStyle.NONE:
+        if self.split_arguments.split_style != SplitStyle.NONE:
             for i, (_, _, writer) in enumerate(self.split_strategy):
                 writer.add_uint16(LLM_KV_SPLIT_NO, i)
                 writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards)
@@ -219,7 +245,7 @@ class GGUFManager:
             size = SplitStrategy.format_n_bytes_to_str(sum(SplitStrategy.get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only"
             print(f" {shard_path}: n_tensors = {len(shard_tensors) if shard_tensors else 0}, total_size = {size}")

-        if self.args.dry_run:
+        if self.split_arguments.dry_run:
             print("\nDry run, not writing files")
             # instantiating GGUFWriters creates files
             for name, _, _ in self.split_strategy:
@@ -232,6 +258,7 @@ class GGUFManager:

         for i, (_, tensors, writer) in enumerate(self.split_strategy):
             if tensors:
+                print(f"\nWriting to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
                 for j, (name, tensor) in enumerate(tensors):
                     n_elements = int(np.prod(tensor.shape))
                     # logic from convert.py
@@ -251,13 +278,12 @@ class GGUFManager:
                         f"[{j + 1:{padi}d}/{len(tensors)}] Writing tensor {name:38s} | size {size:16} | type {dtype:8} | T+{int(elapsed):4}"
                     )
                     writer.add_tensor(name, tensor)

             writer.write_header_to_file()
             writer.write_kv_data_to_file()
             writer.write_tensors_to_file()

             if tensors:
-                print(f"\nWriting to shard {i + 1}/{self.total_shards} with {len(tensors)}/{running_total} remaining tensors (of {self.total_tensors} total)")
                 running_total -= len(tensors)

         if write_tensor_data:
@@ -473,6 +501,9 @@ class GGUFManager:
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)

+    def add_tokenizer_pre(self, pre: str) -> None:
+        self.add_string(Keys.Tokenizer.PRE, pre)
+
     def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None:
         self.add_array(Keys.Tokenizer.LIST, tokens)
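
Usage note (not part of the patch): a minimal sketch of how the new plumbing fits together, mirroring what main() in convert.py does after this change. It assumes the gguf-py package from this branch, which exposes SplitArguments and GGUFManager; the argparse flags are inferred from the fields SplitArguments reads (--split, --split-max-tensors, --split-max-size, --dry-run, --large-first-shard), the key/tensor names are made up for illustration, and GGUFManager.add_tensor is assumed to mirror GGUFWriter.add_tensor as the converters' write_tensors paths expect.

    import argparse
    from pathlib import Path

    import numpy as np
    import gguf

    parser = argparse.ArgumentParser()
    parser.add_argument("--split", action="store_true")
    parser.add_argument("--split-max-tensors", type=int)
    parser.add_argument("--split-max-size", type=str)
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--large-first-shard", action="store_true")
    args = parser.parse_args(["--split", "--split-max-tensors", "128"])

    # SplitArguments() gives the no-split defaults; SplitArguments(args) derives
    # split_style (NONE / TENSORS / SIZE) from whichever limit was provided
    split_arguments = gguf.SplitArguments(args) if args.split else gguf.SplitArguments()

    # GGUFManager shards behind the usual writer interface: add metadata and
    # tensors as with GGUFWriter, then write everything out in one call
    writer = gguf.GGUFManager(Path("model.gguf"), "llama", split_arguments)
    writer.add_string("general.name", "example")
    writer.add_tensor("blk.0.ffn_up.weight", np.zeros((32, 32), dtype=np.float32))
    writer.write_to_file()
    writer.close()

Keeping the split knobs on a small typed object instead of passing the raw argparse.Namespace around also lets library callers build one without a CLI, which is exactly what write_vocab_only does above with gguf.SplitArguments().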