convert-hf : support bfloat16 conversion

This commit is contained in:
Francis Couture-Harpin 2024-05-08 18:18:37 -04:00
parent 4426e2987b
commit 6f8d280073
5 changed files with 345 additions and 162 deletions

View file

@ -12,7 +12,7 @@ import sys
from enum import IntEnum
from pathlib import Path
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
import numpy as np
import torch
@ -65,8 +65,8 @@ class Model:
model_arch: gguf.MODEL_ARCH
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
if self.__class__ == Model:
raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
if type(self) == Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
self.dir_model = dir_model
self.ftype = ftype
self.fname_out = fname_out
@ -215,6 +215,23 @@ class Model:
return False
def write_tensors(self):
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
def np_fp32_to_bf16(n: np.ndarray):
# force nan to quiet
n = np.where((n & 0x7fffffff) > 0x7f800000, n | (64 << 16), n)
# flush subnormals to zero
n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
# round to nearest even
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
return n
# Doing this row-wise is much, much faster than element-wise, hence the signature
v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
if self.lazy:
# TODO: find a way to implicitly wrap np.vectorize functions
# NOTE: the type is changed to reflect otypes passed to np.vectorize above
v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
for name, data_torch in self.get_tensors():
@ -239,10 +256,7 @@ class Model:
data: np.ndarray = data # type hint
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
data_qtype: gguf.GGMLQuantizationType | None = None
# when both are True, f32 should win
extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
@ -254,20 +268,33 @@ class Model:
# if f16 desired, convert any float32 2-dim weight tensors to float16
extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2)
# when both extra_f32 and extra_f16 are False, convert to float32 by default
if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16):
data = data.astype(np.float32)
if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32:
if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
if data_dtype != np.float16:
data = data.astype(np.float16)
data_qtype = gguf.GGMLQuantizationType.F16
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
if data_dtype != np.float32:
data = data.astype(np.float32)
data = v_fp32_to_bf16(data.view(np.int32))
assert data.dtype == np.int16
data_qtype = gguf.GGMLQuantizationType.BF16
else: # by default, convert to float32
if data_dtype != np.float32:
data = data.astype(np.float32)
data_qtype = gguf.GGMLQuantizationType.F32
assert data_qtype is not None
# reverse shape to make it similar to the internal ggml dimension order
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
# n_dims is implicit in the shape
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}")
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
self.gguf_writer.add_tensor(new_name, data)
self.gguf_writer.add_tensor(new_name, data,raw_dtype=data_qtype)
def write(self):
self.write_tensors()
@ -2281,92 +2308,40 @@ class OlmoModel(Model):
# tree of lazy tensors
class LazyTorchTensor:
_meta: Tensor
_data: Tensor | None
_args: tuple
_func: Callable[[tuple], Tensor] | None
def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
self._meta = meta
self._data = data
self._args = args
self._func = func
@staticmethod
def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
# TODO: dict and set
if isinstance(o, (list, tuple)):
L = []
for item in o:
L.append(LazyTorchTensor._recurse_apply(item, fn))
if isinstance(o, tuple):
L = tuple(L)
return L
elif isinstance(o, LazyTorchTensor):
return fn(o)
else:
return o
def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
def wrapped_fn(*args, **kwargs):
if kwargs is None:
kwargs = {}
args = ((self,) if use_self else ()) + args
meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)
return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
return wrapped_fn
def __getattr__(self, __name: str) -> Any:
meta_attr = getattr(self._meta, __name)
if callable(meta_attr):
return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
elif isinstance(meta_attr, torch.Tensor):
# for things like self.T
return self._wrap_fn(lambda s: getattr(s, __name))(self)
else:
return meta_attr
class LazyTorchTensor(gguf.LazyBase):
_tensor_type = torch.Tensor
# to keep the type-checker happy
dtype: torch.dtype
shape: torch.Size
# only used when converting a torch.Tensor to a np.ndarray
_dtype_map: dict[torch.dtype, type] = {
torch.float16: np.float16,
torch.float32: np.float32,
}
def numpy(self) -> gguf.LazyTensor:
def numpy(self) -> gguf.LazyNumpyTensor:
dtype = self._dtype_map[self.dtype]
return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
return gguf.LazyNumpyTensor(
meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
lazy=self._lazy,
args=(self,),
func=(lambda s: s[0].numpy())
)
@overload
@staticmethod
def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
@overload
@staticmethod
def to_eager(t: tuple) -> tuple: ...
@staticmethod
def to_eager(t: Any) -> Any:
def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
# wake up the lazy tensor
if _t._data is None and _t._func is not None:
# recurse into its arguments
_t._args = LazyTorchTensor.to_eager(_t._args)
_t._data = _t._func(_t._args)
if _t._data is not None:
return _t._data
else:
raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
# recurse into lists and/or tuples, keeping their structure
return LazyTorchTensor._recurse_apply(t, simple_to_eager)
@staticmethod
def from_eager(t: Tensor) -> Tensor:
if (t.__class__ == LazyTorchTensor):
@classmethod
def eager_to_meta(cls, t: Tensor) -> Tensor:
if t.is_meta:
return t
return LazyTorchTensor(meta=t.detach().to("meta"), data=t) # type: ignore
return t.detach().to("meta")
@classmethod
def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
m = m.detach()
if not m.is_meta:
m = m.to("meta")
m.dtype = dtype
return m
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
@ -2377,28 +2352,8 @@ class LazyTorchTensor:
if func is torch.Tensor.numpy:
return args[0].numpy()
if func is torch.equal:
eager_args = LazyTorchTensor.to_eager(args)
return func(*eager_args, **kwargs)
return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
# special methods bypass __getattr__, so they need to be added manually
# ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
# NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
# as self._meta is currently used), because then the following
# operations would by default not be wrapped, and so not propagated
# when the tensor is made eager.
# It's better to get non-silent errors for not-yet-supported operators.
# TODO: add more when needed to avoid clutter, or find a more concise way
def __neg__(self, *args): # mamba
return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
def __add__(self, *args): # gemma
return self._wrap_fn(torch.Tensor.__add__)(self, *args)
def __getitem__(self, *args): # bloom falcon refact internlm2
return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
def parse_args() -> argparse.Namespace:
@ -2417,8 +2372,8 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16"], default="f16",
help="output format - use f32 for float32, f16 for float16",
"--outtype", type=str, choices=["f32", "f16", "bf16"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16",
)
parser.add_argument(
"--bigendian", action="store_true",
@ -2472,9 +2427,10 @@ def main() -> None:
logger.error(f'Error: {args.model} is not a directory')
sys.exit(1)
ftype_map = {
"f32": gguf.GGMLQuantizationType.F32,
"f16": gguf.GGMLQuantizationType.F16,
ftype_map: dict[str, gguf.LlamaFileType] = {
"f32": gguf.LlamaFileType.ALL_F32,
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
}
if args.outfile is not None:

View file

@ -1,4 +1,5 @@
from .constants import *
from .lazy import *
from .gguf_reader import *
from .gguf_writer import *
from .tensor_mapping import *

View file

@ -820,6 +820,47 @@ class GGMLQuantizationType(IntEnum):
BF16 = 30
# TODO: add GGMLFileType from ggml_ftype in ggml.h
# from llama_ftype in llama.h
# ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE.
class LlamaFileType(IntEnum):
ALL_F32 = 0
MOSTLY_F16 = 1 # except 1d tensors
MOSTLY_Q4_0 = 2 # except 1d tensors
MOSTLY_Q4_1 = 3 # except 1d tensors
MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
# MOSTLY_Q4_2 = 5 # support has been removed
# MOSTLY_Q4_3 = 6 # support has been removed
MOSTLY_Q8_0 = 7 # except 1d tensors
MOSTLY_Q5_0 = 8 # except 1d tensors
MOSTLY_Q5_1 = 9 # except 1d tensors
MOSTLY_Q2_K = 10 # except 1d tensors
MOSTLY_Q3_K_S = 11 # except 1d tensors
MOSTLY_Q3_K_M = 12 # except 1d tensors
MOSTLY_Q3_K_L = 13 # except 1d tensors
MOSTLY_Q4_K_S = 14 # except 1d tensors
MOSTLY_Q4_K_M = 15 # except 1d tensors
MOSTLY_Q5_K_S = 16 # except 1d tensors
MOSTLY_Q5_K_M = 17 # except 1d tensors
MOSTLY_Q6_K = 18 # except 1d tensors
MOSTLY_IQ2_XXS = 19 # except 1d tensors
MOSTLY_IQ2_XS = 20 # except 1d tensors
MOSTLY_Q2_K_S = 21 # except 1d tensors
MOSTLY_IQ3_XS = 22 # except 1d tensors
MOSTLY_IQ3_XXS = 23 # except 1d tensors
MOSTLY_IQ1_S = 24 # except 1d tensors
MOSTLY_IQ4_NL = 25 # except 1d tensors
MOSTLY_IQ3_S = 26 # except 1d tensors
MOSTLY_IQ3_M = 27 # except 1d tensors
MOSTLY_IQ2_S = 28 # except 1d tensors
MOSTLY_IQ2_M = 29 # except 1d tensors
MOSTLY_IQ4_XS = 30 # except 1d tensors
MOSTLY_IQ1_M = 31 # except 1d tensors
MOSTLY_BF16 = 32 # except 1d tensors
class GGUFEndian(IntEnum):
LITTLE = 0
BIG = 1

View file

@ -7,7 +7,7 @@ import struct
import tempfile
from enum import Enum, auto
from io import BufferedWriter
from typing import IO, Any, Callable, Sequence, Mapping
from typing import IO, Any, Sequence, Mapping
from string import ascii_letters, digits
import numpy as np
@ -28,47 +28,6 @@ from .constants import (
logger = logging.getLogger(__name__)
class LazyTensor:
data: Callable[[], np.ndarray[Any, Any]]
# to avoid too deep recursion
functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
dtype: np.dtype[Any]
shape: tuple[int, ...]
def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
self.data = data
self.functions = []
self.dtype = np.dtype(dtype)
self.shape = shape
def astype(self, dtype: type, **kwargs) -> LazyTensor:
self.functions.append(lambda n: n.astype(dtype, **kwargs))
self.dtype = np.dtype(dtype)
return self
@property
def nbytes(self) -> int:
size = 1
for n in self.shape:
size *= n
return size * self.dtype.itemsize
def tofile(self, *args, **kwargs) -> None:
data = self.data()
for f in self.functions:
data = f(data)
assert data.shape == self.shape
assert data.dtype == self.dtype
assert data.nbytes == self.nbytes
self.functions = []
self.data = lambda: data
data.tofile(*args, **kwargs)
def byteswap(self, *args, **kwargs) -> LazyTensor:
self.functions.append(lambda n: n.byteswap(*args, **kwargs))
return self
class WriterState(Enum):
EMPTY = auto()
HEADER = auto()
@ -79,7 +38,7 @@ class WriterState(Enum):
class GGUFWriter:
fout: BufferedWriter
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
tensors: list[np.ndarray[Any, Any] | LazyTensor]
tensors: list[np.ndarray[Any, Any]]
_simple_value_packing = {
GGUFValueType.UINT8: "B",
GGUFValueType.INT8: "b",
@ -278,7 +237,7 @@ class GGUFWriter:
self.ti_data_count += 1
def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None,
) -> None:
if self.endianess == GGUFEndian.BIG:
@ -303,7 +262,7 @@ class GGUFWriter:
if pad != 0:
fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
if self.state is not WriterState.TI_DATA:
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')

226
gguf-py/gguf/lazy.py Normal file
View file

@ -0,0 +1,226 @@
from __future__ import annotations
from abc import ABC, ABCMeta, abstractmethod
import logging
from typing import Any, Callable, Iterable, SupportsIndex
from collections import deque
import numpy as np
from numpy.typing import DTypeLike, NDArray
from numpy._typing import _ShapeLike
logger = logging.getLogger(__name__)
class LazyMeta(ABCMeta):
def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
def __getattr__(self, __name: str) -> Any:
meta_attr = getattr(self._meta, __name)
if callable(meta_attr):
return type(self)._wrap_fn(
(lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)),
use_self=self,
)
elif isinstance(meta_attr, self._tensor_type):
# e.g. self.T with torch.Tensor should still be wrapped
return type(self)._wrap_fn(lambda s: getattr(s, __name))(self)
else:
# no need to wrap non-tensor properties,
# and they likely don't depend on the actual contents of the tensor
return meta_attr
namespace["__getattr__"] = __getattr__
# need to make a builder for the wrapped wrapper to copy the name,
# or else it fails with very cryptic error messages,
# because somehow the same string would end up in every closures
def mk_wrap(op_name: str, *, meta_noop: bool = False):
# need to wrap the wrapper to get self
def wrapped_special_op(self, *args, **kwargs):
return type(self)._wrap_fn(
getattr(type(self)._tensor_type, op_name),
meta_noop=meta_noop,
)(self, *args, **kwargs)
return wrapped_special_op
# special methods bypass __getattr__, so they need to be added manually
# ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
# NOTE: doing this from a metaclass is very convenient
# TODO: make this even more comprehensive
for binary_op in (
"lt", "le", "eq", "ne", "ge", "gt", "not"
"abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul",
"neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor",
"iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
"radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
):
attr_name = f"__{binary_op}__"
# the result of these operators usually has the same shape and dtype as the input,
# so evaluation on the meta tensor can be skipped.
namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)
for special_op in (
"getitem", "setitem", "len",
):
attr_name = f"__{special_op}__"
namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)
return super().__new__(cls, name, bases, namespace, **kwargs)
# Tree of lazy tensors
class LazyBase(ABC, metaclass=LazyMeta):
_tensor_type: type
_meta: Any
_data: Any | None
_lazy: deque[LazyBase] # shared within a graph, to avoid deep recursion when making eager
_args: tuple
_func: Callable[[tuple], Any] | None
def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
super().__init__()
self._meta = meta
self._data = data
self._lazy = lazy if lazy is not None else deque()
self._args = args
self._func = func
assert self._func is not None or self._data is not None
if self._data is None:
self._lazy.append(self)
def __init_subclass__(cls) -> None:
if "_tensor_type" not in cls.__dict__:
raise TypeError(f"property '_tensor_type' must be defined for {cls!r}")
return super().__init_subclass__()
@staticmethod
def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
# TODO: dict and set
if isinstance(o, (list, tuple)):
L = []
for item in o:
L.append(LazyBase._recurse_apply(item, fn))
if isinstance(o, tuple):
L = tuple(L)
return L
elif isinstance(o, LazyBase):
return fn(o)
else:
return o
@classmethod
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]:
def wrapped_fn(*args, **kwargs):
if kwargs is None:
kwargs = {}
args = ((use_self,) if use_self is not None else ()) + args
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
if isinstance(meta_noop, bool) and not meta_noop:
try:
res = fn(*meta_args, **kwargs)
except NotImplementedError:
# running some operations on PyTorch's Meta tensors can cause this exception
res = None
else:
# some operators don't need to actually run on the meta tensors
assert len(args) > 0
res = args[0]
assert isinstance(res, cls)
res = res._meta
# allow operations to override the dtype
if meta_noop != True:
res = cls.meta_with_dtype(res, meta_noop)
if isinstance(res, cls._tensor_type):
def collect_replace(t: LazyBase):
if collect_replace.shared_lazy is None:
collect_replace.shared_lazy = t._lazy
else:
collect_replace.shared_lazy.extend(t._lazy)
t._lazy = collect_replace.shared_lazy
# emulating a static variable
collect_replace.shared_lazy = None
LazyBase._recurse_apply(args, collect_replace)
shared_lazy = collect_replace.shared_lazy
return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
else:
del res # not needed
# non-tensor return likely relies on the contents of the args
# (e.g. the result of torch.equal)
eager_args = cls.to_eager(args)
return fn(*eager_args, **kwargs)
return wrapped_fn
@classmethod
def to_eager(cls, t: Any) -> Any:
def simple_to_eager(_t: LazyBase) -> Any:
def already_eager_to_eager(_t: LazyBase) -> Any:
assert _t._data is not None
return _t._data
while _t._data is None:
lt = _t._lazy.popleft()
if lt._data is not None:
raise ValueError(f"{lt} did not belong in the lazy queue")
assert lt._func is not None
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
lt._data = lt._func(lt._args)
# sanity check
assert lt._data.dtype == lt._meta.dtype
assert lt._data.shape == lt._meta.shape
return _t._data
# recurse into lists and/or tuples, keeping their structure
return cls._recurse_apply(t, simple_to_eager)
@classmethod
def eager_to_meta(cls, t: Any) -> Any:
return cls.meta_with_dtype(t, t.dtype)
# must be overridden, meta tensor init is backend-specific
@classmethod
@abstractmethod
def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass
@classmethod
def from_eager(cls, t: Any) -> Any:
if type(t) == cls:
# already eager
return t
elif isinstance(t, cls._tensor_type):
return cls(meta=cls.eager_to_meta(t), data=t)
else:
return TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}")
class LazyNumpyTensor(LazyBase):
_tensor_type = np.ndarray
@classmethod
def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]:
# The initial idea was to use np.nan as the fill value,
# but non-float types like np.int16 can't use that.
# So zero it is.
cheat = np.zeros(1, dtype)
return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape))
def astype(self, dtype, *args, **kwargs):
meta = type(self).meta_with_dtype(self._meta, dtype)
full_args = (self, dtype,) + args
# very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
def tofile(self, *args, **kwargs):
eager = LazyNumpyTensor.to_eager(self)
return eager.tofile(*args, **kwargs)
# TODO: __array_function__