From f54daa07358bc55f1d5022df2ae62c9fa8b804cd Mon Sep 17 00:00:00 2001
From: KerfuffleV2
Date: Wed, 23 Aug 2023 15:42:19 -0600
Subject: [PATCH] Allow convert.py to convert to q8_0

Fix issue with bounded_parallel_map greedily consuming its input iterator

Display elapsed time during conversion
---
 convert.py | 117 +++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 91 insertions(+), 26 deletions(-)

diff --git a/convert.py b/convert.py
index b7c626d84..8a888849f 100755
--- a/convert.py
+++ b/convert.py
@@ -3,6 +3,7 @@
 import gguf
 import argparse
 import concurrent.futures
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
 import copy
 import enum
 import faulthandler
@@ -17,6 +18,7 @@ import re
 import signal
 import struct
 import sys
+import time
 import zipfile
 
 import numpy as np
@@ -50,7 +52,13 @@ DT_F32 = UnquantizedDataType('F32')
 DT_I32 = UnquantizedDataType('I32')
 DT_BF16 = UnquantizedDataType('BF16')
 
-DataType = Union[UnquantizedDataType]
+@dataclass(frozen=True)
+class QuantizedDataType:
+    name: str
+
+DT_Q8_0 = QuantizedDataType('Q8_0')
+
+DataType = Union[UnquantizedDataType, QuantizedDataType]
 
 DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
     DT_BF16: np.dtype(np.uint16),
@@ -73,8 +81,9 @@ SAFETENSORS_DATA_TYPES: Dict[str, DataType] = {
 # TODO: rename to LLAMAFileType
 # TODO: move to `gguf.py`
 class GGMLFileType(enum.IntEnum):
-    AllF32 = 0
-    MostlyF16 = 1  # except 1d tensors
+    AllF32     = 0
+    MostlyF16  = 1  # except 1d tensors
+    MostlyQ8_0 = 7  # except 1d tensors
 
     def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
         if len(tensor.shape) == 1:
@@ -84,6 +93,8 @@ class GGMLFileType(enum.IntEnum):
             return DT_F32
         elif self == GGMLFileType.MostlyF16:
             return DT_F16
+        elif self == GGMLFileType.MostlyQ8_0:
+            return DT_Q8_0
         else:
             raise ValueError(self)
 
@@ -391,7 +402,10 @@ class UnquantizedTensor(Tensor):
         self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype]
 
     def astype(self, data_type: DataType) -> Tensor:
-        dtype = DATA_TYPE_TO_NUMPY[data_type]
+        if data_type == DT_Q8_0:
+            dtype = DATA_TYPE_TO_NUMPY[DT_F32]
+        else:
+            dtype = DATA_TYPE_TO_NUMPY[data_type]
         if self.data_type == DT_BF16:
             self.ndarray = bf16_to_fp32(self.ndarray)
         return UnquantizedTensor(self.ndarray.astype(dtype))
@@ -455,7 +469,7 @@ class LazyTensor:
 
     def load(self) -> Tensor:
         ret = self._load()
-        assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
+        assert ret.data_type == self.data_type or (self.data_type is DT_Q8_0 and ret.data_type is DT_F32), (self.data_type, ret.data_type, self.description)
        return ret
 
     def astype(self, data_type: DataType) -> 'LazyTensor':
@@ -699,23 +713,32 @@ def lazy_load_file(path: Path) -> ModelPlus:
 In = TypeVar('In')
 Out = TypeVar('Out')
 
-def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]:
+def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
     '''Parallel map, but with backpressure.  If the caller doesn't call `next`
     fast enough, this will stop calling `func` at some point rather than
     letting results pile up in memory.
     Specifically, there is a max of one output value buffered per thread.'''
-    with concurrent.futures.ThreadPoolExecutor() as executor:
+    iterable = iter(iterable)
+    with factory(max_workers = max_workers) as executor:
         futures: List[concurrent.futures.Future[Out]] = []
-        items_rev = list(iterable)[::-1]
-        for i in range(min(concurrency, len(items_rev))):
-            futures.append(executor.submit(func, items_rev.pop()))
-        while futures:
+        done = False
+        for i in range(concurrency):
+            try:
+                nexti = next(iterable)
+            except StopIteration:
+                break
+            futures.append(executor.submit(func, nexti))
+        while not done or futures:
             result = futures.pop(0).result()
-            if items_rev:
-                futures.append(executor.submit(func, items_rev.pop()))
+            while len(futures) < concurrency:
+                try:
+                    nexti = next(iterable)
+                except StopIteration:
+                    done = True
+                    break
+                futures.append(executor.submit(func, nexti))
             yield result
-
 
 def check_vocab_size(params: Params, vocab: Vocab) -> None:
     if params.n_vocab != vocab.vocab_size:
         assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab)
@@ -732,6 +755,22 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None:
             msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
         raise Exception(msg)
 
+#### Mini Q8_0 quantization in Python
+QK8_0 = 32
+BLOCK_Q8_0 = np.dtype([('d', '<f2'), ('qs', 'i1', (QK8_0,))])
+
+def quantize_array_q8_0(arr):
+    assert arr.size % QK8_0 == 0 and arr.size != 0, f'Bad array size {arr.size}'
+    assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
+    n_blocks = arr.size // QK8_0
+    blocks = arr.reshape((n_blocks, QK8_0))
+    return np.fromiter(map(quantize_block_q8_0, blocks), count = n_blocks, dtype = BLOCK_Q8_0)
+
+def quantize_block_q8_0(blk):
+    d = abs(blk).max() / np.float32(127)
+    if d == np.float32(0):
+        return (np.float16(0), blk)
+    return (np.float16(d), (blk * (np.float32(1) / d)).round())
 
 class OutputFile:
     def __init__(self, fname_out: Path) -> None:
@@ -777,9 +816,16 @@ class OutputFile:
         n_elements = 1
         for dim in tensor.shape:
             n_elements *= dim
-        data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
-        data_nbytes = n_elements * data_type.itemsize
-        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes)
+        if tensor.data_type == DT_Q8_0:
+            assert n_elements > 0 and n_elements % QK8_0 == 0, f'Cannot quantize as Q8_0, {n_elements} not a multiple of block size {QK8_0}'
+            data_type = BLOCK_Q8_0
+            raw_dtype = gguf.GGMLQuantizationType.Q8_0
+            data_nbytes = n_elements + (n_elements // QK8_0) * 2
+        else:
+            data_type = DATA_TYPE_TO_NUMPY[tensor.data_type]
+            data_nbytes = n_elements * data_type.itemsize
+            raw_dtype = None
+        self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype)
 
     def write_meta(self) -> None:
         self.gguf.write_header_to_file()
@@ -805,7 +851,19 @@ class OutputFile:
         of.close()
 
     @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def do_item(item: Tuple[str, LazyTensor]) -> Tuple[DataType, NDArray]:
+        name, lazy_tensor = item
+        tensor = lazy_tensor.load().to_ggml()
+        return (lazy_tensor.data_type, tensor.ndarray)
+
+    @staticmethod
+    def maybe_do_quant(item: Tuple[DataType, NDArray]) -> NDArray:
+        if item[0] == DT_Q8_0:
+            return quantize_array_q8_0(item[1])
+        return item[1]
+
+    @staticmethod
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab) -> None:
         check_vocab_size(params, vocab)
 
         of = OutputFile(fname_out)
@@ -821,16 +879,19 @@ class OutputFile:
         of.write_meta()
         of.write_tensor_info()
 
-        def do_item(item: Tuple[str, LazyTensor]) -> NDArray:
-            name, lazy_tensor = item
-            return lazy_tensor.load().to_ggml().ndarray
-
         # tensor data
-        ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8)
+        ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = 8)
+        if ftype == GGMLFileType.MostlyQ8_0:
+            ndarrays = bounded_parallel_map(OutputFile.maybe_do_quant, ndarrays, concurrency = 8, max_workers = 8, factory = ProcessPoolExecutor)
+        else:
+            ndarrays = map(OutputFile.maybe_do_quant, ndarrays)
+
+        start = time.time()
         for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
+            elapsed = time.time() - start
             size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
             padi = len(str(len(model)))
-            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}")
+            print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:6} | T+{int(elapsed):4}")
             of.gguf.write_tensor_data(ndarray)
 
         of.close()
@@ -842,6 +903,8 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
         return GGMLFileType.AllF32
     if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)):
         return GGMLFileType.MostlyF16
+    if output_type_str == "q8_0":
+        return GGMLFileType.MostlyQ8_0
 
     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
 
@@ -993,6 +1056,7 @@ def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path:
     namestr = {
         GGMLFileType.AllF32:    "f32",
         GGMLFileType.MostlyF16: "f16",
+        GGMLFileType.MostlyQ8_0: "q8_0",
     }[file_type]
     ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf"
     if ret in model_paths:
@@ -1016,7 +1080,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
     parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
     parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
-    parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
+    parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format (default: based on input)")
     parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
     parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
@@ -1043,6 +1107,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
         params.ftype = {
             "f32": GGMLFileType.AllF32,
             "f16": GGMLFileType.MostlyF16,
+            "q8_0": GGMLFileType.MostlyQ8_0,
         }[args.outtype]
 
     print(f"params = {params}")
@@ -1074,7 +1139,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
 
     params.ftype = ftype
     print(f"Writing {outfile}, format {ftype}")
 
-    OutputFile.write_all(outfile, params, model, vocab)
+    OutputFile.write_all(outfile, ftype, params, model, vocab)
     print(f"Wrote {outfile}")
 
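
Usage: with this patch applied, Q8_0 output is selected through the extended --outtype flag; the model path below is illustrative:

    python convert.py models/7B/ --outtype q8_0

If --outfile is not given, default_outfile() names the result ggml-model-q8_0.gguf next to the input files.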
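As a sanity check on the Q8_0 block layout used above (per 32-element block: one little-endian float16 scale 'd' followed by QK8_0 = 32 signed int8 quants 'qs', 34 bytes total), the sketch below dequantizes the output of quantize_array_q8_0 and compares it against the input. It assumes the patched convert.py is importable; dequantize_array_q8_0 and the error bound are illustrative helpers, not part of the patch:

    import numpy as np

    from convert import BLOCK_Q8_0, QK8_0, quantize_array_q8_0  # patched convert.py on sys.path

    def dequantize_array_q8_0(blocks):
        # Invert Q8_0: reconstruct each element as d * qs (per-block scale times int8 quant).
        d = blocks['d'].astype(np.float32)[:, None]   # (n_blocks, 1) scales
        qs = blocks['qs'].astype(np.float32)          # (n_blocks, QK8_0) quants
        return (d * qs).reshape(-1)

    arr = np.random.randn(4 * QK8_0).astype(np.float32)
    q = quantize_array_q8_0(arr)
    assert q.dtype == BLOCK_Q8_0
    assert q.nbytes == arr.size + (arr.size // QK8_0) * 2  # same byte math as data_nbytes above
    # Reconstruction error stays below one quantization step (d = max|block| / 127).
    assert np.allclose(dequantize_array_q8_0(q), arr, atol=float(np.abs(arr).max()) / 127)

The bound holds because each value is rounded to the nearest multiple of its block scale d, so the rounding error is at most d / 2, plus a small contribution from storing d as float16.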