From f09674fbbd9d79d65ca6cf531b242128c527acd9 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Fri, 3 May 2024 22:00:05 -0400
Subject: [PATCH] convert-hf : save memory with lazy evaluation

---
 convert-hf-to-gguf.py       | 135 ++++++++++++++++++++++++++++++++++--
 gguf-py/gguf/gguf_writer.py |  71 +++++++++++++++++--
 2 files changed, 196 insertions(+), 10 deletions(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index bbefd46f6..51e191616 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -12,7 +12,7 @@ import sys
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
 
 import numpy as np
 import torch
@@ -63,7 +63,7 @@ class Model:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
-    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
+    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
         if self.__class__ == Model:
             raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
@@ -81,6 +81,9 @@ class Model:
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensors = dict(self.get_tensors())
+        if not eager:
+            for k, v in self.tensors.items():
+                self.tensors[k] = LazyTorchTensor.from_eager(v)
 
     @classmethod
     def __init_subclass__(cls):
@@ -245,9 +248,11 @@ class Model:
 
     def write(self):
         self.write_tensors()
+        self.tensors.clear()  # save memory by not keeping references to the tensors
+
         self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
-        self.gguf_writer.write_tensors_to_file()
+        self.gguf_writer.write_tensors_to_file(progress=True)
         self.gguf_writer.close()
 
     def write_vocab(self):
@@ -2229,6 +2234,124 @@ class OlmoModel(Model):
 ###### CONVERSION LOGIC ######
 
 
+# tree of lazy tensors
+class LazyTorchTensor:
+    _meta: Tensor
+    _data: Tensor | None
+    _args: list[Any]
+    _func: Callable[[list[Any]], Tensor] | None = None
+
+    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: list[Any] | None = None, func: Callable[[list[Any]], Tensor] | None = None):
+        self._meta = meta
+        self._data = data
+        self._args = args if args is not None else []
+        self._func = func
+
+    @staticmethod
+    def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
+        # TODO: dicts
+        if isinstance(o, (list, tuple)):
+            l = []
+            for item in o:
+                l.append(LazyTorchTensor._recurse_apply(item, fn))
+            if isinstance(o, tuple):
+                l = tuple(l)
+            return l
+        elif isinstance(o, LazyTorchTensor):
+            return fn(o)
+        else:
+            return o
+
+    def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
+        def wrapped_fn(*args, **kwargs):
+            if kwargs is None:
+                kwargs = {}
+            args_list = ([self] if use_self else []) + list(args)
+
+            meta_args = LazyTorchTensor._recurse_apply(args_list, lambda t: t._meta)
+
+            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args_list, func=lambda a: fn(*a, **kwargs))
+        return wrapped_fn
+
+    def __getattr__(self, __name: str) -> Any:
+        meta_attr = getattr(self._meta, __name)
+        if not callable(meta_attr):
+            return meta_attr
+        else:
+            return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
+
+    _dtype_map: dict[torch.dtype, type] = {
+        torch.float16: np.float16,
+        torch.float32: np.float32,
+    }
+
+    def numpy(self) -> gguf.LazyTensor:
+        dtype = self._dtype_map[self.dtype]
+        return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
+
+    @overload
+    @staticmethod
+    def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
+
+    @overload
+    @staticmethod
+    def to_eager(t: list[Tensor | LazyTorchTensor]) -> list[Tensor]: ...
+
+    @staticmethod
+    def to_eager(t: Any) -> Any:
+        def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
+            # wake up the lazy tensor
+            if _t._data is None and _t._func is not None:
+                # recurse into its arguments
+                _t._args = LazyTorchTensor.to_eager(_t._args)
+                _t._data = _t._func(_t._args)
+            if _t._data is not None:
+                return _t._data
+            else:
+                raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
+
+        # recurse into lists and/or tuples, keeping their structure
+        return LazyTorchTensor._recurse_apply(t, simple_to_eager)
+
+    @staticmethod
+    def from_eager(t: Tensor) -> Tensor:
+        if (t.__class__ == LazyTorchTensor):
+            return t
+        return LazyTorchTensor(meta=t.detach().to("meta"), data=t)  # type: ignore
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        del types  # unused
+
+        if kwargs is None:
+            kwargs = {}
+
+        if func is torch.Tensor.numpy:
+            return args[0].numpy()
+        if func is torch.equal:
+            eager_args = LazyTorchTensor.to_eager(args)
+            return func(*eager_args, **kwargs)
+
+        return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
+
+    # special methods bypass __getattr__, so they need to be added manually
+    # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
+    # NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
+    #       as self._meta is currently used), because then the following
+    #       operations would by default not be wrapped, and so not propagated
+    #       when the tensor is made eager.
+    #       It's better to get non-silent errors for not-yet-supported operators.
+    # TODO: add more when needed to avoid clutter, or find a more concise way
+    def __neg__(self, *args):  # mamba
+        return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
+
+    def __add__(self, *args):  # gemma
+        return self._wrap_fn(torch.Tensor.__add__)(self, *args)
+
+    def __getitem__(self, *args):  # bloom falcon internlm2
+        return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
+
+
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
         description="Convert a huggingface model to a GGML compatible file")
@@ -2260,6 +2383,10 @@ def parse_args() -> argparse.Namespace:
         "--use-temp-file", action="store_true",
         help="use the tempfile library while processing (helpful when running out of memory, process killed)",
     )
+    parser.add_argument(
+        "--no-lazy", action="store_true",
+        help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
+    )
     parser.add_argument(
         "--model-name", type=str, default=None,
         help="name of the model",
@@ -2313,7 +2440,7 @@ def main() -> None:
 
     with torch.inference_mode():
         model_class = Model.from_model_architecture(hparams["architectures"][0])
-        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file)
+        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy)
 
         logger.info("Set model parameters")
         model_instance.set_gguf_parameters()
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 469aed8ef..8dcf9330b 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -7,7 +7,7 @@ import struct
 import tempfile
 from enum import Enum, auto
 from io import BufferedWriter
-from typing import IO, Any, Sequence, Mapping
+from typing import IO, Any, Callable, Sequence, Mapping
 from string import ascii_letters, digits
 
 import numpy as np
@@ -28,6 +28,47 @@ from .constants import (
 
 logger = logging.getLogger(__name__)
 
+
+class LazyTensor:
+    data: Callable[[], np.ndarray[Any, Any]]
+    # to avoid too deep recursion
+    functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
+    dtype: np.dtype[Any]
+    shape: tuple[int, ...]
+
+    def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
+        self.data = data
+        self.functions = []
+        self.dtype = np.dtype(dtype)
+        self.shape = shape
+
+    def astype(self, dtype: type, **kwargs) -> LazyTensor:
+        self.functions.append(lambda n: n.astype(dtype, **kwargs))
+        self.dtype = np.dtype(dtype)
+        return self
+
+    @property
+    def nbytes(self) -> int:
+        size = 1
+        for n in self.shape:
+            size *= n
+        return size * self.dtype.itemsize
+
+    def tofile(self, *args, **kwargs) -> None:
+        data = self.data()
+        for f in self.functions:
+            data = f(data)
+        assert data.shape == self.shape
+        assert data.dtype == self.dtype
+        assert data.nbytes == self.nbytes
+        self.functions = []
+        self.data = lambda: data
+        data.tofile(*args, **kwargs)
+
+    def byteswap(self, *args, **kwargs) -> LazyTensor:
+        self.functions.append(lambda n: n.byteswap(*args, **kwargs))
+        return self
+
+
 class WriterState(Enum):
     EMPTY = auto()
     HEADER = auto()
@@ -38,7 +79,7 @@ class WriterState(Enum):
 class GGUFWriter:
     fout: BufferedWriter
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
-    tensors: list[np.ndarray[Any, Any]]
+    tensors: list[np.ndarray[Any, Any] | LazyTensor]
     _simple_value_packing = {
         GGUFValueType.UINT8: "B",
         GGUFValueType.INT8: "b",
@@ -237,7 +278,7 @@ class GGUFWriter:
         self.ti_data_count += 1
 
     def add_tensor(
-        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
+        self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
         raw_dtype: GGMLQuantizationType | None = None,
     ) -> None:
         if self.endianess == GGUFEndian.BIG:
@@ -262,7 +303,7 @@ class GGUFWriter:
         if pad != 0:
             fp.write(bytes([0] * pad))
 
-    def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
+    def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
         if self.state is not WriterState.TI_DATA:
            raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
 
@@ -272,15 +313,33 @@
         tensor.tofile(self.fout)
         self.write_padding(self.fout, tensor.nbytes)
 
-    def write_tensors_to_file(self) -> None:
+    def write_tensors_to_file(self, *, progress: bool = False) -> None:
         self.write_ti_data_to_file()
 
         self.write_padding(self.fout, self.fout.tell())
 
         if self.temp_file is None:
+            self.tensors.reverse()  # to pop from the "beginning" in constant time
+
+            if progress:
+                from tqdm import tqdm
+
+                total_bytes = sum(t.nbytes for t in self.tensors)
+
+                bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+
+                while True:
+                    try:
+                        tensor = self.tensors.pop()
+                    except IndexError:
+                        break
+                    tensor.tofile(self.fout)
+                    bar.update(tensor.nbytes)
+                    self.write_padding(self.fout, tensor.nbytes)
+                return
             while True:
                 try:
-                    tensor = self.tensors.pop(0)
+                    tensor = self.tensors.pop()
                 except IndexError:
                     break
                 tensor.tofile(self.fout)
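
Illustrative note, not part of the patch: the core pattern above is that every torch tensor is replaced by a wrapper whose shape/dtype bookkeeping is done on a PyTorch "meta" tensor (which allocates no data), while the real data is produced by a stored closure only when the output file is finally written. Below is a minimal standalone sketch of that pattern under the assumption that only PyTorch is installed; the names DeferredTensor and materialize() are invented for this example and do not appear in the patch.

# Minimal sketch (assumed names, not from the patch) of deferred evaluation on meta tensors:
# operations are recorded against a "meta" tensor now and replayed on real data later.
import torch


class DeferredTensor:
    def __init__(self, meta: torch.Tensor, compute):
        self._meta = meta        # carries shape/dtype only, allocates no data
        self._compute = compute  # zero-argument closure producing the real tensor

    @staticmethod
    def from_eager(t: torch.Tensor) -> "DeferredTensor":
        return DeferredTensor(t.detach().to("meta"), lambda: t)

    def permute(self, *dims) -> "DeferredTensor":
        # record the operation: apply it to the meta tensor now, to the real data later
        return DeferredTensor(self._meta.permute(*dims),
                              lambda: self.materialize().permute(*dims))

    def materialize(self) -> torch.Tensor:
        return self._compute()

    @property
    def shape(self) -> torch.Size:
        return self._meta.shape


if __name__ == "__main__":
    lazy = DeferredTensor.from_eager(torch.zeros(4, 2))
    transposed = lazy.permute(1, 0)         # no tensor data is touched here
    print(transposed.shape)                 # torch.Size([2, 4]), known from the meta tensor
    print(transposed.materialize().sum())   # the real computation happens only now

The patch generalizes this idea through __getattr__ and __torch_function__ so that whatever operations the Model subclasses apply are recorded without being special-cased, and gguf.LazyTensor plays the equivalent role on the NumPy side, deferring astype/byteswap until write_tensors_to_file() streams each tensor to disk.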