convert-hf : support bfloat16 conversion

This commit is contained in:
Francis Couture-Harpin 2024-05-08 18:18:37 -04:00
parent 4426e2987b
commit 6f8d280073
5 changed files with 345 additions and 162 deletions

View file

@ -12,7 +12,7 @@ import sys
from enum import IntEnum
from pathlib import Path
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Sequence, TypeVar, cast
import numpy as np
import torch
@ -65,8 +65,8 @@ class Model:
model_arch: gguf.MODEL_ARCH
def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool, use_temp_file: bool, eager: bool):
if self.__class__ == Model:
raise TypeError(f"{self.__class__.__name__!r} should not be directly instantiated")
if type(self) == Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
self.dir_model = dir_model
self.ftype = ftype
self.fname_out = fname_out
@ -215,6 +215,23 @@ class Model:
return False
def write_tensors(self):
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
def np_fp32_to_bf16(n: np.ndarray):
# force nan to quiet
n = np.where((n & 0x7fffffff) > 0x7f800000, n | (64 << 16), n)
# flush subnormals to zero
n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
# round to nearest even
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
return n
# Doing this row-wise is much, much faster than element-wise, hence the signature
v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
if self.lazy:
# TODO: find a way to implicitly wrap np.vectorize functions
# NOTE: the type is changed to reflect otypes passed to np.vectorize above
v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
for name, data_torch in self.get_tensors():
@ -239,10 +256,7 @@ class Model:
data: np.ndarray = data # type hint
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
data_qtype: gguf.GGMLQuantizationType | None = None
# when both are True, f32 should win
extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
@ -254,20 +268,33 @@ class Model:
# if f16 desired, convert any float32 2-dim weight tensors to float16
extra_f16 = extra_f16 or (name.endswith(".weight") and n_dims >= 2)
# when both extra_f32 and extra_f16 are False, convert to float32 by default
if self.ftype == 1 and data_dtype == np.float16 and (extra_f32 or not extra_f16):
data = data.astype(np.float32)
if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
if data_dtype != np.float16:
data = data.astype(np.float16)
data_qtype = gguf.GGMLQuantizationType.F16
if self.ftype == 1 and data_dtype == np.float32 and extra_f16 and not extra_f32:
data = data.astype(np.float16)
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
if data_dtype != np.float32:
data = data.astype(np.float32)
data = v_fp32_to_bf16(data.view(np.int32))
assert data.dtype == np.int16
data_qtype = gguf.GGMLQuantizationType.BF16
else: # by default, convert to float32
if data_dtype != np.float32:
data = data.astype(np.float32)
data_qtype = gguf.GGMLQuantizationType.F32
assert data_qtype is not None
# reverse shape to make it similar to the internal ggml dimension order
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
# n_dims is implicit in the shape
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data.dtype}, shape = {shape_str}")
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
self.gguf_writer.add_tensor(new_name, data)
self.gguf_writer.add_tensor(new_name, data,raw_dtype=data_qtype)
def write(self):
self.write_tensors()
@ -2281,92 +2308,40 @@ class OlmoModel(Model):
# tree of lazy tensors
class LazyTorchTensor:
_meta: Tensor
_data: Tensor | None
_args: tuple
_func: Callable[[tuple], Tensor] | None
def __init__(self, *, meta: Tensor, data: Tensor | None = None, args: tuple = (), func: Callable[[tuple], Tensor] | None = None):
self._meta = meta
self._data = data
self._args = args
self._func = func
@staticmethod
def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
# TODO: dict and set
if isinstance(o, (list, tuple)):
L = []
for item in o:
L.append(LazyTorchTensor._recurse_apply(item, fn))
if isinstance(o, tuple):
L = tuple(L)
return L
elif isinstance(o, LazyTorchTensor):
return fn(o)
else:
return o
def _wrap_fn(self, fn: Callable, use_self: bool = False) -> Callable[[Any], LazyTorchTensor]:
def wrapped_fn(*args, **kwargs):
if kwargs is None:
kwargs = {}
args = ((self,) if use_self else ()) + args
meta_args = LazyTorchTensor._recurse_apply(args, lambda t: t._meta)
return LazyTorchTensor(meta=fn(*meta_args, **kwargs), args=args, func=lambda a: fn(*a, **kwargs))
return wrapped_fn
def __getattr__(self, __name: str) -> Any:
meta_attr = getattr(self._meta, __name)
if callable(meta_attr):
return self._wrap_fn(getattr(torch.Tensor, __name), use_self=True)
elif isinstance(meta_attr, torch.Tensor):
# for things like self.T
return self._wrap_fn(lambda s: getattr(s, __name))(self)
else:
return meta_attr
class LazyTorchTensor(gguf.LazyBase):
_tensor_type = torch.Tensor
# to keep the type-checker happy
dtype: torch.dtype
shape: torch.Size
# only used when converting a torch.Tensor to a np.ndarray
_dtype_map: dict[torch.dtype, type] = {
torch.float16: np.float16,
torch.float32: np.float32,
}
def numpy(self) -> gguf.LazyTensor:
def numpy(self) -> gguf.LazyNumpyTensor:
dtype = self._dtype_map[self.dtype]
return gguf.LazyTensor(lambda: LazyTorchTensor.to_eager(self).numpy(), dtype=dtype, shape=self.shape)
return gguf.LazyNumpyTensor(
meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
lazy=self._lazy,
args=(self,),
func=(lambda s: s[0].numpy())
)
@overload
@staticmethod
def to_eager(t: Tensor | LazyTorchTensor) -> Tensor: ...
@overload
@staticmethod
def to_eager(t: tuple) -> tuple: ...
@staticmethod
def to_eager(t: Any) -> Any:
def simple_to_eager(_t: LazyTorchTensor) -> Tensor:
# wake up the lazy tensor
if _t._data is None and _t._func is not None:
# recurse into its arguments
_t._args = LazyTorchTensor.to_eager(_t._args)
_t._data = _t._func(_t._args)
if _t._data is not None:
return _t._data
else:
raise ValueError(f"Could not compute lazy tensor {_t!r} with args {_t._args!r}")
# recurse into lists and/or tuples, keeping their structure
return LazyTorchTensor._recurse_apply(t, simple_to_eager)
@staticmethod
def from_eager(t: Tensor) -> Tensor:
if (t.__class__ == LazyTorchTensor):
@classmethod
def eager_to_meta(cls, t: Tensor) -> Tensor:
if t.is_meta:
return t
return LazyTorchTensor(meta=t.detach().to("meta"), data=t) # type: ignore
return t.detach().to("meta")
@classmethod
def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
m = m.detach()
if not m.is_meta:
m = m.to("meta")
m.dtype = dtype
return m
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
@ -2377,28 +2352,8 @@ class LazyTorchTensor:
if func is torch.Tensor.numpy:
return args[0].numpy()
if func is torch.equal:
eager_args = LazyTorchTensor.to_eager(args)
return func(*eager_args, **kwargs)
return LazyTorchTensor._wrap_fn(args[0], func)(*args, **kwargs)
# special methods bypass __getattr__, so they need to be added manually
# ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
# NOTE: LazyTorchTensor can't be a subclass of Tensor (and then be used
# as self._meta is currently used), because then the following
# operations would by default not be wrapped, and so not propagated
# when the tensor is made eager.
# It's better to get non-silent errors for not-yet-supported operators.
# TODO: add more when needed to avoid clutter, or find a more concise way
def __neg__(self, *args): # mamba
return self._wrap_fn(torch.Tensor.__neg__)(self, *args)
def __add__(self, *args): # gemma
return self._wrap_fn(torch.Tensor.__add__)(self, *args)
def __getitem__(self, *args): # bloom falcon refact internlm2
return self._wrap_fn(torch.Tensor.__getitem__)(self, *args)
return LazyTorchTensor._wrap_fn(func)(*args, **kwargs)
def parse_args() -> argparse.Namespace:
@ -2417,8 +2372,8 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16"], default="f16",
help="output format - use f32 for float32, f16 for float16",
"--outtype", type=str, choices=["f32", "f16", "bf16"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16",
)
parser.add_argument(
"--bigendian", action="store_true",
@ -2472,9 +2427,10 @@ def main() -> None:
logger.error(f'Error: {args.model} is not a directory')
sys.exit(1)
ftype_map = {
"f32": gguf.GGMLQuantizationType.F32,
"f16": gguf.GGMLQuantizationType.F16,
ftype_map: dict[str, gguf.LlamaFileType] = {
"f32": gguf.LlamaFileType.ALL_F32,
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
}
if args.outfile is not None: