From f6fd3ea4e9a0a68dadbbd3956778672b7735e2d5 Mon Sep 17 00:00:00 2001
From: Christian Zhou-Zheng
Date: Wed, 5 Jun 2024 12:28:40 -0400
Subject: [PATCH] further simplify GGUFManager

---
 convert-hf-to-gguf-update.py |  7 ---
 convert-hf-to-gguf.py        | 12 +++--
 gguf-py/gguf/gguf_manager.py | 86 +++++++++++++++++++-----------------
 3 files changed, 54 insertions(+), 51 deletions(-)

diff --git a/convert-hf-to-gguf-update.py b/convert-hf-to-gguf-update.py
index 8ea2d82e3..84b72348d 100755
--- a/convert-hf-to-gguf-update.py
+++ b/convert-hf-to-gguf-update.py
@@ -81,14 +81,7 @@ models = [
     {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM!
     {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", },
     {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", },
-<<<<<<< Updated upstream
     {"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", },
-=======
-    {"name": "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom-7b1", },
-    {"name": "gptbigcode", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/gpt_bigcode-santacoder", },
-    {"name": "phi2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", },
-    {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B-Chat", },
->>>>>>> Stashed changes
 ]
 
 

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index e415692ba..4b3dfdd70 100644
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -60,7 +60,7 @@ class Model:
     tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     fname_out: Path
-    gguf_writer: gguf.GGUFManager
+    gguf_writer: gguf.GGUFWriter
 
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
@@ -329,11 +329,16 @@ class Model:
 
     def write(self):
         self.write_tensors()
-        self.gguf_writer.write_to_file()
+        self.gguf_writer.write_header_to_file()
+        self.gguf_writer.write_kv_data_to_file()
+        self.gguf_writer.write_ti_data_to_file()
         self.gguf_writer.close()
 
     def write_vocab(self):
-        self.gguf_writer.write_to_file(meta_only=True)
+        if self.gguf_writer.split_arguments.split:
+            raise ValueError('Splitting the vocabulary is not supported')
+        self.gguf_writer.write_header_to_file()
+        self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()
 
     @staticmethod
@@ -1563,7 +1568,6 @@ class MiniCPMModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-# TODO what the hell is this?
 @Model.register("QWenLMHeadModel")
 class QwenModel(Model):
     model_arch = gguf.MODEL_ARCH.QWEN

diff --git a/gguf-py/gguf/gguf_manager.py b/gguf-py/gguf/gguf_manager.py
index 13a2f0eea..aeec9642c 100644
--- a/gguf-py/gguf/gguf_manager.py
+++ b/gguf-py/gguf/gguf_manager.py
@@ -2,8 +2,7 @@ from __future__ import annotations
 
 import os
 from enum import IntEnum
-from typing import TYPE_CHECKING, Any, Sequence, Mapping
-from string import ascii_letters, digits
+from typing import TYPE_CHECKING, Any, Sequence
 from argparse import Namespace
 from math import ceil
 from collections import deque
@@ -18,7 +17,7 @@ from .constants import (
     GGUFEndian, GGUFValueType
 )
-from .gguf_writer import GGUFWriter
+from .gguf_writer import GGUFWriter, WriterState
 
 SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
 
 
@@ -74,7 +73,7 @@ class SplitStrategy(deque):
             self.append((shard, model[start:stop], GGUFWriter(shard, arch, use_temp_file=use_temp_file, endianess=endianess)))
 
         elif split_arguments.split_style == SplitStyle.SIZE:
-            shards = deque()
+            shards = []
 
             # we have to determine the shards first to determine how many shards there will be in total - two passes
             for i, shard in enumerate(model):
@@ -135,7 +134,6 @@ class SplitStrategy(deque):
             num /= 1024.0
         return f"{num:.1f}T - over 1TB, --split recommended"
 
-# TODO fall back to normal GGUFWriter in convert-hf-to-gguf.py if no --split
 class GGUFManager(GGUFWriter):
     kv_data: KVTempData
     tensors: list[TensorTempData]
@@ -145,27 +143,25 @@ class GGUFManager(GGUFWriter):
 
     def __init__(self, path: os.PathLike[str] | str, arch: str, split_arguments: SplitArguments,
                  use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE
                  ) -> None:
-        # TODO be able to use superclass constructor
-        # super().__init__(path, arch, use_temp_file=use_temp_file, endianess=endianess)
+        # we intentionally don't call superclass constructor
         self.arch = arch
         self.path = path
         self.endianess = endianess
-        self.offset_tensor = 0
         self.kv_data = {}
         self.tensors = []
-        # TODO how many of these do you need
         self.split_strategy = None
-        self.total_shards = None
-        self.total_tensors = None
+        self.total_shards = 0
+        self.total_tensors = 0
         self.use_temp_file = use_temp_file
         self.split_arguments = split_arguments
         self.recent_key = None
+        self.state = WriterState.EMPTY
         self.add_architecture()
 
-    # TODO split back into write_header_to_file, write_kv_data_to_file, write_ti_data_to_file
-    def write_to_file(self, meta_only: bool = False) -> None:
+    def write_header_to_file(self) -> None:
+        if self.state is not WriterState.EMPTY:
+            raise ValueError(f'Expected GGUFManager state to be EMPTY, got {self.state}')
 
-        # here is the first place you can assume you have all tensors written and you can establish the size of the file - so logic goes here
         self.total_tensors = len(self.tensors)
         total_size = sum(SplitStrategy.get_tensor_size(tensor[1]) for tensor in self.tensors)
@@ -182,26 +178,6 @@ class GGUFManager(GGUFWriter):
         del self.tensors
         self.total_shards = len(self.split_strategy)
 
-        # only the first shard needs all the KV data
-        for key, (value, etype) in self.kv_data.items():
-            self.split_strategy[0][2].add_key(key)
-            self.split_strategy[0][2].add_val(value, etype)
-
-        if self.split_arguments.split_style != SplitStyle.NONE:
-            for i, (_, _, writer) in enumerate(self.split_strategy):
-                writer.add_uint16(LLM_KV_SPLIT_NO, i)
-                writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards)
-                writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors)
-
-        # metadata/vocab only can write and return here
-        if meta_only:
-            for i, (_, _, writer) in enumerate(self.split_strategy):
-                writer.write_header_to_file()
-                writer.write_kv_data_to_file()
-            return
-
-        # tensor writing code starts here
-
         print("\nWriting the following files:")
         for (shard_path, shard_tensors, _) in self.split_strategy:
             size = SplitStrategy.format_n_bytes_to_str(sum(SplitStrategy.get_tensor_size(t[1]) for t in shard_tensors)) if shard_tensors else "negligible - metadata only"
@@ -214,10 +190,38 @@ class GGUFManager(GGUFWriter):
                 os.remove(name)
             return
 
-        # run add_tensor_info, write data, then write_tensor_data - taken from convert.py
+        self.state = WriterState.HEADER
+
+    def write_kv_data_to_file(self) -> None:
+        if self.split_arguments.dry_run:
+            return
+
+        if self.state is not WriterState.HEADER:
+            raise ValueError(f'Expected GGUFManager state to be HEADER, got {self.state}')
+
+        # only the first shard needs all the KV data
+        for key, (value, etype) in self.kv_data.items():
+            self.split_strategy[0][2].add_key(key)
+            self.split_strategy[0][2].add_val(value, etype)
+
+        # the other shards need shard data
+        if self.split_arguments.split_style != SplitStyle.NONE:
+            for i, (_, _, writer) in enumerate(self.split_strategy):
+                writer.add_uint16(LLM_KV_SPLIT_NO, i)
+                writer.add_uint16(LLM_KV_SPLIT_COUNT, self.total_shards)
+                writer.add_int32(LLM_KV_SPLIT_TENSORS_COUNT, self.total_tensors)
+
+        self.state = WriterState.KV_DATA
+
+    def write_ti_data_to_file(self) -> None:
+        if self.split_arguments.dry_run:
+            return
+
+        if self.state is not WriterState.KV_DATA:
+            raise ValueError(f'Expected GGUFManager state to be KV_DATA, got {self.state}')
+
         running_total = self.total_tensors
-        ct = 0
-        while True:
+        for ct in range(self.total_shards):
             try:
                 (_, tensors, writer) = self.split_strategy.popleft()
                 tensors = deque(tensors) if tensors else None
@@ -234,15 +238,17 @@ class GGUFManager(GGUFWriter):
                         break
                     writer.add_tensor(name, tensor, raw_dtype=dtype)
 
-            print(f"Writing to shard {ct + 1}/{self.total_shards} with {shard_num_tensors}/{running_total} remaining tensors (of {self.total_tensors} total)")
-            running_total -= shard_num_tensors
+            print(f"Writing to shard {ct}/{self.total_shards} with {shard_num_tensors}/{running_total} remaining tensors (of {self.total_tensors} total)")
+            running_total -= shard_num_tensors
 
+            # need to write everything down here
             writer.write_header_to_file()
             writer.write_kv_data_to_file()
             writer.write_tensors_to_file(progress=True)
-            ct = ct + 1
             del tensors
 
+        self.state = WriterState.TI_DATA
+
     # override add_key, add_val to handle kv data separately
     def add_key(self, key: str) -> None:
         self.recent_key = key
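
Usage note: with this patch, GGUFManager follows the same three-phase write
sequence as GGUFWriter, guarded by WriterState transitions
(EMPTY -> HEADER -> KV_DATA -> TI_DATA). A minimal sketch of how a converter
drives it is below; the class and method names come from this patch, but the
SplitArguments construction and the metadata/tensor values are hypothetical
placeholders, not part of the patch:

    import numpy as np
    import gguf

    # hypothetical no-split defaults; the real converter builds this
    # from its argparse Namespace
    split_args = gguf.SplitArguments()
    writer = gguf.GGUFManager("model.gguf", "llama", split_args)

    # add_key/add_val are overridden, so inherited helpers such as
    # add_uint32 buffer KV pairs in memory instead of writing immediately
    writer.add_uint32("llama.context_length", 4096)
    writer.add_tensor("token_embd.weight", np.zeros((8, 8), dtype=np.float32))

    writer.write_header_to_file()   # EMPTY -> HEADER: plans shards (dry runs stop here)
    writer.write_kv_data_to_file()  # HEADER -> KV_DATA: all KV pairs go to shard 0
    writer.write_ti_data_to_file()  # KV_DATA -> TI_DATA: streams tensors shard by shard
    writer.close()

Each shard's underlying GGUFWriter still performs its own header/KV/tensor
writes inside write_ti_data_to_file, which is why the manager only needs to
buffer tensors until the shard plan is known.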