convert-hf : save memory with lazy evaluation

Francis Couture-Harpin 2024-05-03 22:00:05 -04:00
parent 215a0d38c8
commit f09674fbbd
2 changed files with 196 additions and 10 deletions

convert-hf-to-gguf.py

@@ -12,7 +12,7 @@ import sys
 from enum import IntEnum
 from pathlib import Path
 from hashlib import sha256
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
 
 import numpy as np
 import torch
@@ -63,7 +63,7 @@ class Model:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
-    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool):
+    def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
         if self.__class__ == Model:
             raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
         self.dir_model = dir_model
@@ -81,6 +81,9 @@ class Model:
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensors = dict(self.get_tensors())
+        if not eager:
+            for k, v in self.tensors.items():
+                self.tensors[k] = LazyTorchTensor.from_eager(v)
 
     @classmethod
     def __init_subclass__(cls):
@@ -245,9 +248,11 @@ class Model:
     def write(self):
         self.write_tensors()
+        self.tensors.clear()  # save memory by not keeping references to the tensors
         self.gguf_writer.write_header_to_file()
         self.gguf_writer.write_kv_data_to_file()
-        self.gguf_writer.write_tensors_to_file()
+        self.gguf_writer.write_tensors_to_file(progress=True)
         self.gguf_writer.close()
 
     def write_vocab(self):
@@ -2229,6 +2234,124 @@ class OlmoModel(Model):
 
 ###### CONVERSION LOGIC ######
 
+
+# tree of lazy tensors
+class LazyTorchTensor:
+    _meta: Tensor
+    _data: Tensor | None
+    _args: list[Any]
+    _func: Callable[[list[Any]], Tensor] | None = None
+
+    def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: list[Any] | None = None, func: Callable[[list[Any]], Tensor] | None = None):
+        self._meta = meta
+        self._data = data
+        self._args = args if args is not None else []
+        self._func = func
+
+    @staticmethod
+    def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
+        # TODO: dicts
+        if isinstance(o, (list, tuple)):
+            l = []
+            for item in o:
+                l.append(LazyTorchTensor._recurse_apply(item, fn))
+            if isinstance(o, tuple):
+                l = tuple(l)
+            return l
+        elif isinstance(o, LazyTorchTensor):
+            return fn(o)
+        else:
+            return o
+
+    def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
+        def wrapped_fn(*args, **kwargs):
+            if kwargs is None:
+                kwargs = {}
+            args_list = ([self] if use_self else []) + list(args)
+
+            meta_args = LazyTorchTensor._recurse_apply(args_list, lambda t: t._meta)
+
+            return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args_list, func=lambda a: fn(*a, **kwargs))
+        return wrapped_fn
+
+    def __getattr__(self, __name: str) -> Any:
+        meta_attr = getattr(self._meta, __name)
+        if not callable(meta_attr):
+            return meta_attr
+        else:
+            return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
+
+    _dtype_map: dict[torch.dtype, type] = {
+        torch.float16: np.float16,
+        torch.float32: np.float32,
+    }
+
+    def numpy(self) -> gguf.LazyTensor:
+        dtype = self._dtype_map[self.dtype]
+        return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
+
+    @overload
+    @staticmethod
+    def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
+
+    @overload
+    @staticmethod
+    def to_eager(t: list[Tensor | LazyTorchTensor]) -> list[Tensor]: ...
+
+    @staticmethod
+    def to_eager(t: Any) -> Any:
+        def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
+            # wake up the lazy tensor
+            if _t._data is None and _t._func is not None:
+                # recurse into its arguments
+                _t._args = LazyTorchTensor.to_eager(_t._args)
+                _t._data = _t._func(_t._args)
+            if _t._data is not None:
+                return _t._data
+            else:
+                raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
+
+        # recurse into lists and/or tuples, keeping their structure
+        return LazyTorchTensor._recurse_apply(t, simple_to_eager)
+
+    @staticmethod
+    def from_eager(t: Tensor) -> Tensor:
+        if (t.__class__ == LazyTorchTensor):
+            return t
+        return LazyTorchTensor(meta=t.detach().to("meta"), data=t)  # type: ignore
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        del types  # unused
+
+        if kwargs is None:
+            kwargs = {}
+
+        if func is torch.Tensor.numpy:
+            return args[0].numpy()
+        if func is torch.equal:
+            eager_args = LazyTorchTensor.to_eager(args)
+            return func(*eager_args, **kwargs)
+
+        return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
+
+    # special methods bypass __getattr__, so they need to be added manually
+    # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
+    # NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used as self._meta is currently used),
+    #       because then the following operations would by default not be wrapped,
+    #       and so not propagated when the tensor is made eager.
+    # It's better to get non-silent errors for not-yet-supported operators.
+    # TODO: add more when needed to avoid clutter, or find a more concise way
+    def __neg__(self, *args):  # mamba
+        return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
+
+    def __add__(self, *args):  # gemma
+        return self._wrap_fn(torch.Tensor.__add__)(self, *args)
+
+    def __getitem__(self, *args):  # bloom falcon internlm2
+        return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
+
+
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
         description="Convert a huggingface model to a GGML compatible file")
@@ -2260,6 +2383,10 @@ def parse_args() -> argparse.Namespace:
         "--use-temp-file", action="store_true",
         help="use the tempfile library while processing (helpful when running out of memory, process killed)",
     )
+    parser.add_argument(
+        "--no-lazy", action="store_true",
+        help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
+    )
     parser.add_argument(
         "--model-name", type=str, default=None,
         help="name of the model",
@@ -2313,7 +2440,7 @@ def main() -> None:
 
     with torch.inference_mode():
         model_class = Model.from_model_architecture(hparams["architectures"][0])
-        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file)
+        model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file, args.no_lazy)
 
         logger.info("Set model parameters")
         model_instance.set_gguf_parameters()
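
For intuition, here is a minimal standalone sketch (not part of the commit) of the trick LazyTorchTensor relies on: shape and dtype propagation runs on PyTorch's meta device, while the real computation is captured in a closure and only executed when the tensor is made eager. The names below (LazySketch, etc.) are illustrative, not from the codebase.

import torch

class LazySketch:
    def __init__(self, meta: torch.Tensor, func):
        self.meta = meta      # carries shape/dtype only, no storage
        self.func = func      # recomputes the real data on demand

    @staticmethod
    def from_eager(t: torch.Tensor) -> "LazySketch":
        return LazySketch(t.detach().to("meta"), lambda: t)

    def permute(self, *dims) -> "LazySketch":
        # run the op on the meta tensor now, defer the real work
        return LazySketch(self.meta.permute(*dims), lambda: self.func().permute(*dims))

    def to_eager(self) -> torch.Tensor:
        return self.func()

t = LazySketch.from_eager(torch.ones(2, 3))
lazy = t.permute(1, 0)
print(lazy.meta.shape)        # torch.Size([3, 2]), known without computing anything
print(lazy.to_eager().shape)  # torch.Size([3, 2]), computed only here

The real implementation generalizes this with _wrap_fn and __torch_function__ so arbitrary tensor methods are deferred, and raises a hard error for operators that are not yet wrapped.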

gguf-py/gguf/gguf_writer.py

@@ -7,7 +7,7 @@ import struct
 import tempfile
 from enum import Enum, auto
 from io import BufferedWriter
-from typing import IO, Any, Sequence, Mapping
+from typing import IO, Any, Callable, Sequence, Mapping
 from string import ascii_letters, digits
 
 import numpy as np
@@ -28,6 +28,47 @@ from .constants import (
 
 logger = logging.getLogger(__name__)
 
+
+class LazyTensor:
+    data: Callable[[], np.ndarray[Any, Any]]
+    # to avoid too deep recursion
+    functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
+    dtype: np.dtype[Any]
+    shape: tuple[int, ...]
+
+    def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
+        self.data = data
+        self.functions = []
+        self.dtype = np.dtype(dtype)
+        self.shape = shape
+
+    def astype(self, dtype: type, **kwargs) -> LazyTensor:
+        self.functions.append(lambda n: n.astype(dtype, **kwargs))
+        self.dtype = np.dtype(dtype)
+        return self
+
+    @property
+    def nbytes(self) -> int:
+        size = 1
+        for n in self.shape:
+            size *= n
+        return size * self.dtype.itemsize
+
+    def tofile(self, *args, **kwargs) -> None:
+        data = self.data()
+        for f in self.functions:
+            data = f(data)
+        assert data.shape == self.shape
+        assert data.dtype == self.dtype
+        assert data.nbytes == self.nbytes
+        self.functions = []
+        self.data = lambda: data
+        data.tofile(*args, **kwargs)
+
+    def byteswap(self, *args, **kwargs) -> LazyTensor:
+        self.functions.append(lambda n: n.byteswap(*args, **kwargs))
+        return self
+
+
 class WriterState(Enum):
     EMPTY = auto()
     HEADER = auto()
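
A rough usage sketch of the numpy-side LazyTensor above (not from the commit; the import path is assumed): the loader callback and any queued astype()/byteswap() transforms only run when tofile() is called, while nbytes is derived from the declared shape and dtype alone.

import numpy as np
from gguf.gguf_writer import LazyTensor  # assumed import path

def load() -> np.ndarray:
    # stand-in for an expensive read, e.g. materializing a torch tensor
    return np.arange(6, dtype=np.float32).reshape(2, 3)

lt = LazyTensor(load, dtype=np.float32, shape=(2, 3))
lt = lt.astype(np.float16)   # queued, not executed yet
print(lt.nbytes)             # 12 bytes, computed from shape and dtype alone
with open("tensor.bin", "wb") as f:
    lt.tofile(f)             # load() and the queued astype() run only here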
@@ -38,7 +79,7 @@ class WriterState(Enum):
 class GGUFWriter:
     fout: BufferedWriter
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None
-    tensors: list[np.ndarray[Any, Any]]
+    tensors: list[np.ndarray[Any, Any] | LazyTensor]
     _simple_value_packing = {
         GGUFValueType.UINT8: "B",
         GGUFValueType.INT8: "b",
@@ -237,7 +278,7 @@ class GGUFWriter:
         self.ti_data_count += 1
 
     def add_tensor(
-        self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
+        self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
         raw_dtype: GGMLQuantizationType | None = None,
     ) -> None:
         if self.endianess == GGUFEndian.BIG:
@@ -262,7 +303,7 @@ class GGUFWriter:
         if pad != 0:
             fp.write(bytes([0] * pad))
 
-    def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
+    def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
         if self.state is not WriterState.TI_DATA:
             raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
@@ -272,15 +313,33 @@ class GGUFWriter:
         tensor.tofile(self.fout)
         self.write_padding(self.fout, tensor.nbytes)
 
-    def write_tensors_to_file(self) -> None:
+    def write_tensors_to_file(self, *, progress: bool = False) -> None:
         self.write_ti_data_to_file()
 
         self.write_padding(self.fout, self.fout.tell())
 
         if self.temp_file is None:
+            self.tensors.reverse()  # to pop from the "beginning" in constant time
+            if progress:
+                from tqdm import tqdm
+
+                total_bytes = sum(t.nbytes for t in self.tensors)
+
+                bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True)
+
+                while True:
+                    try:
+                        tensor = self.tensors.pop()
+                    except IndexError:
+                        break
+                    tensor.tofile(self.fout)
+                    bar.update(tensor.nbytes)
+                    self.write_padding(self.fout, tensor.nbytes)
+                return
             while True:
                 try:
-                    tensor = self.tensors.pop(0)
+                    tensor = self.tensors.pop()
                 except IndexError:
                     break
 
                 tensor.tofile(self.fout)
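
The progress branch above reverses the tensor list and pops from the end so each tensor's buffer is released as soon as it is written, with a tqdm byte counter around the loop. A minimal standalone sketch of that pattern (file name and byte buffers are placeholders, not from the commit):

from tqdm import tqdm

chunks = [bytes(10), bytes(20), bytes(30)]  # stand-ins for tensor data
chunks.reverse()                            # pop() from the end is O(1), unlike pop(0)

bar = tqdm(desc="Writing", total=sum(len(c) for c in chunks), unit="byte", unit_scale=True)
with open("out.bin", "wb") as fout:
    while True:
        try:
            chunk = chunks.pop()            # drops the list's reference once written
        except IndexError:
            break
        fout.write(chunk)
        bar.update(len(chunk))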