convert-hf : support bfloat16 conversion (#7158)
* convert-hf : support bfloat16 conversion * gguf-py : flake8 fixes * convert-hf : add missing space after comma * convert-hf : get bit-exact same output as ./quantize The quantization version was missing. * convert-hf : don't round bf16 NANs * convert-hf : save some memory with np.int16 intermediate bf16 weights * convert-hf : more closely match llama.cpp with which weights to keep in f32 * convert-hf : add --outtype auto-f16 A reason for this to exist is for model quantizers who want an initial GGUF with the most fidelity to the original model while still using a 16-bit float type instead of 32-bit floats. * convert-hf : remove a semicolon because flake8 doesn't like it It's a reflex from when programming in C/C++, I guess. * convert-hf : support outtype templating in outfile name * convert-hf : rename --outtype auto-f16 to --outtype auto
This commit is contained in:
parent
fae9d234b6
commit
5a419926b0
5 changed files with 404 additions and 182 deletions
|
@ -1,4 +1,5 @@
|
|||
from .constants import *
|
||||
from .lazy import *
|
||||
from .gguf_reader import *
|
||||
from .gguf_writer import *
|
||||
from .tensor_mapping import *
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Any
|
|||
GGUF_MAGIC = 0x46554747 # "GGUF"
|
||||
GGUF_VERSION = 3
|
||||
GGUF_DEFAULT_ALIGNMENT = 32
|
||||
GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
|
||||
|
||||
#
|
||||
# metadata keys
|
||||
|
@ -838,6 +839,49 @@ class GGMLQuantizationType(IntEnum):
|
|||
BF16 = 30
|
||||
|
||||
|
||||
# TODO: add GGMLFileType from ggml_ftype in ggml.h
|
||||
|
||||
|
||||
# from llama_ftype in llama.h
|
||||
# ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE.
|
||||
class LlamaFileType(IntEnum):
|
||||
ALL_F32 = 0
|
||||
MOSTLY_F16 = 1 # except 1d tensors
|
||||
MOSTLY_Q4_0 = 2 # except 1d tensors
|
||||
MOSTLY_Q4_1 = 3 # except 1d tensors
|
||||
MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
|
||||
# MOSTLY_Q4_2 = 5 # support has been removed
|
||||
# MOSTLY_Q4_3 = 6 # support has been removed
|
||||
MOSTLY_Q8_0 = 7 # except 1d tensors
|
||||
MOSTLY_Q5_0 = 8 # except 1d tensors
|
||||
MOSTLY_Q5_1 = 9 # except 1d tensors
|
||||
MOSTLY_Q2_K = 10 # except 1d tensors
|
||||
MOSTLY_Q3_K_S = 11 # except 1d tensors
|
||||
MOSTLY_Q3_K_M = 12 # except 1d tensors
|
||||
MOSTLY_Q3_K_L = 13 # except 1d tensors
|
||||
MOSTLY_Q4_K_S = 14 # except 1d tensors
|
||||
MOSTLY_Q4_K_M = 15 # except 1d tensors
|
||||
MOSTLY_Q5_K_S = 16 # except 1d tensors
|
||||
MOSTLY_Q5_K_M = 17 # except 1d tensors
|
||||
MOSTLY_Q6_K = 18 # except 1d tensors
|
||||
MOSTLY_IQ2_XXS = 19 # except 1d tensors
|
||||
MOSTLY_IQ2_XS = 20 # except 1d tensors
|
||||
MOSTLY_Q2_K_S = 21 # except 1d tensors
|
||||
MOSTLY_IQ3_XS = 22 # except 1d tensors
|
||||
MOSTLY_IQ3_XXS = 23 # except 1d tensors
|
||||
MOSTLY_IQ1_S = 24 # except 1d tensors
|
||||
MOSTLY_IQ4_NL = 25 # except 1d tensors
|
||||
MOSTLY_IQ3_S = 26 # except 1d tensors
|
||||
MOSTLY_IQ3_M = 27 # except 1d tensors
|
||||
MOSTLY_IQ2_S = 28 # except 1d tensors
|
||||
MOSTLY_IQ2_M = 29 # except 1d tensors
|
||||
MOSTLY_IQ4_XS = 30 # except 1d tensors
|
||||
MOSTLY_IQ1_M = 31 # except 1d tensors
|
||||
MOSTLY_BF16 = 32 # except 1d tensors
|
||||
|
||||
GUESSED = 1024 # not specified in the model file
|
||||
|
||||
|
||||
class GGUFEndian(IntEnum):
|
||||
LITTLE = 0
|
||||
BIG = 1
|
||||
|
|
|
@ -7,7 +7,7 @@ import struct
|
|||
import tempfile
|
||||
from enum import Enum, auto
|
||||
from io import BufferedWriter
|
||||
from typing import IO, Any, Callable, Sequence, Mapping
|
||||
from typing import IO, Any, Sequence, Mapping
|
||||
from string import ascii_letters, digits
|
||||
|
||||
import numpy as np
|
||||
|
@ -28,47 +28,6 @@ from .constants import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LazyTensor:
|
||||
data: Callable[[], np.ndarray[Any, Any]]
|
||||
# to avoid too deep recursion
|
||||
functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
|
||||
dtype: np.dtype[Any]
|
||||
shape: tuple[int, ...]
|
||||
|
||||
def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
|
||||
self.data = data
|
||||
self.functions = []
|
||||
self.dtype = np.dtype(dtype)
|
||||
self.shape = shape
|
||||
|
||||
def astype(self, dtype: type, **kwargs) -> LazyTensor:
|
||||
self.functions.append(lambda n: n.astype(dtype, **kwargs))
|
||||
self.dtype = np.dtype(dtype)
|
||||
return self
|
||||
|
||||
@property
|
||||
def nbytes(self) -> int:
|
||||
size = 1
|
||||
for n in self.shape:
|
||||
size *= n
|
||||
return size * self.dtype.itemsize
|
||||
|
||||
def tofile(self, *args, **kwargs) -> None:
|
||||
data = self.data()
|
||||
for f in self.functions:
|
||||
data = f(data)
|
||||
assert data.shape == self.shape
|
||||
assert data.dtype == self.dtype
|
||||
assert data.nbytes == self.nbytes
|
||||
self.functions = []
|
||||
self.data = lambda: data
|
||||
data.tofile(*args, **kwargs)
|
||||
|
||||
def byteswap(self, *args, **kwargs) -> LazyTensor:
|
||||
self.functions.append(lambda n: n.byteswap(*args, **kwargs))
|
||||
return self
|
||||
|
||||
|
||||
class WriterState(Enum):
|
||||
EMPTY = auto()
|
||||
HEADER = auto()
|
||||
|
@ -79,7 +38,7 @@ class WriterState(Enum):
|
|||
class GGUFWriter:
|
||||
fout: BufferedWriter
|
||||
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
|
||||
tensors: list[np.ndarray[Any, Any] | LazyTensor]
|
||||
tensors: list[np.ndarray[Any, Any]]
|
||||
_simple_value_packing = {
|
||||
GGUFValueType.UINT8: "B",
|
||||
GGUFValueType.INT8: "b",
|
||||
|
@ -278,7 +237,7 @@ class GGUFWriter:
|
|||
self.ti_data_count += 1
|
||||
|
||||
def add_tensor(
|
||||
self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
|
||||
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
|
||||
raw_dtype: GGMLQuantizationType | None = None,
|
||||
) -> None:
|
||||
if self.endianess == GGUFEndian.BIG:
|
||||
|
@ -303,7 +262,7 @@ class GGUFWriter:
|
|||
if pad != 0:
|
||||
fp.write(bytes([0] * pad))
|
||||
|
||||
def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
|
||||
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
|
||||
if self.state is not WriterState.TI_DATA:
|
||||
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
|
||||
|
||||
|
@ -391,7 +350,7 @@ class GGUFWriter:
|
|||
def add_name(self, name: str) -> None:
|
||||
self.add_string(Keys.General.NAME, name)
|
||||
|
||||
def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None:
|
||||
def add_quantization_version(self, quantization_version: int) -> None:
|
||||
self.add_uint32(
|
||||
Keys.General.QUANTIZATION_VERSION, quantization_version)
|
||||
|
||||
|
|
225
gguf-py/gguf/lazy.py
Normal file
225
gguf-py/gguf/lazy.py
Normal file
|
@ -0,0 +1,225 @@
|
|||
from __future__ import annotations
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import DTypeLike
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LazyMeta(ABCMeta):
|
||||
|
||||
def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
|
||||
def __getattr__(self, __name: str) -> Any:
|
||||
meta_attr = getattr(self._meta, __name)
|
||||
if callable(meta_attr):
|
||||
return type(self)._wrap_fn(
|
||||
(lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)),
|
||||
use_self=self,
|
||||
)
|
||||
elif isinstance(meta_attr, self._tensor_type):
|
||||
# e.g. self.T with torch.Tensor should still be wrapped
|
||||
return type(self)._wrap_fn(lambda s: getattr(s, __name))(self)
|
||||
else:
|
||||
# no need to wrap non-tensor properties,
|
||||
# and they likely don't depend on the actual contents of the tensor
|
||||
return meta_attr
|
||||
|
||||
namespace["__getattr__"] = __getattr__
|
||||
|
||||
# need to make a builder for the wrapped wrapper to copy the name,
|
||||
# or else it fails with very cryptic error messages,
|
||||
# because somehow the same string would end up in every closures
|
||||
def mk_wrap(op_name: str, *, meta_noop: bool = False):
|
||||
# need to wrap the wrapper to get self
|
||||
def wrapped_special_op(self, *args, **kwargs):
|
||||
return type(self)._wrap_fn(
|
||||
getattr(type(self)._tensor_type, op_name),
|
||||
meta_noop=meta_noop,
|
||||
)(self, *args, **kwargs)
|
||||
return wrapped_special_op
|
||||
|
||||
# special methods bypass __getattr__, so they need to be added manually
|
||||
# ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
|
||||
# NOTE: doing this from a metaclass is very convenient
|
||||
# TODO: make this even more comprehensive
|
||||
for binary_op in (
|
||||
"lt", "le", "eq", "ne", "ge", "gt", "not"
|
||||
"abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul",
|
||||
"neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor",
|
||||
"iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
|
||||
"radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
|
||||
):
|
||||
attr_name = f"__{binary_op}__"
|
||||
# the result of these operators usually has the same shape and dtype as the input,
|
||||
# so evaluation on the meta tensor can be skipped.
|
||||
namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)
|
||||
|
||||
for special_op in (
|
||||
"getitem", "setitem", "len",
|
||||
):
|
||||
attr_name = f"__{special_op}__"
|
||||
namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)
|
||||
|
||||
return super().__new__(cls, name, bases, namespace, **kwargs)
|
||||
|
||||
|
||||
# Tree of lazy tensors
|
||||
class LazyBase(ABC, metaclass=LazyMeta):
|
||||
_tensor_type: type
|
||||
_meta: Any
|
||||
_data: Any | None
|
||||
_lazy: deque[LazyBase] # shared within a graph, to avoid deep recursion when making eager
|
||||
_args: tuple
|
||||
_func: Callable[[tuple], Any] | None
|
||||
|
||||
def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
|
||||
super().__init__()
|
||||
self._meta = meta
|
||||
self._data = data
|
||||
self._lazy = lazy if lazy is not None else deque()
|
||||
self._args = args
|
||||
self._func = func
|
||||
assert self._func is not None or self._data is not None
|
||||
if self._data is None:
|
||||
self._lazy.append(self)
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
if "_tensor_type" not in cls.__dict__:
|
||||
raise TypeError(f"property '_tensor_type' must be defined for {cls!r}")
|
||||
return super().__init_subclass__()
|
||||
|
||||
@staticmethod
|
||||
def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
|
||||
# TODO: dict and set
|
||||
if isinstance(o, (list, tuple)):
|
||||
L = []
|
||||
for item in o:
|
||||
L.append(LazyBase._recurse_apply(item, fn))
|
||||
if isinstance(o, tuple):
|
||||
L = tuple(L)
|
||||
return L
|
||||
elif isinstance(o, LazyBase):
|
||||
return fn(o)
|
||||
else:
|
||||
return o
|
||||
|
||||
@classmethod
|
||||
def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]:
|
||||
def wrapped_fn(*args, **kwargs):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
args = ((use_self,) if use_self is not None else ()) + args
|
||||
|
||||
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
|
||||
|
||||
if isinstance(meta_noop, bool) and not meta_noop:
|
||||
try:
|
||||
res = fn(*meta_args, **kwargs)
|
||||
except NotImplementedError:
|
||||
# running some operations on PyTorch's Meta tensors can cause this exception
|
||||
res = None
|
||||
else:
|
||||
# some operators don't need to actually run on the meta tensors
|
||||
assert len(args) > 0
|
||||
res = args[0]
|
||||
assert isinstance(res, cls)
|
||||
res = res._meta
|
||||
# allow operations to override the dtype
|
||||
if meta_noop is not True:
|
||||
res = cls.meta_with_dtype(res, meta_noop)
|
||||
|
||||
if isinstance(res, cls._tensor_type):
|
||||
def collect_replace(t: LazyBase):
|
||||
if collect_replace.shared_lazy is None:
|
||||
collect_replace.shared_lazy = t._lazy
|
||||
else:
|
||||
collect_replace.shared_lazy.extend(t._lazy)
|
||||
t._lazy = collect_replace.shared_lazy
|
||||
|
||||
# emulating a static variable
|
||||
collect_replace.shared_lazy = None
|
||||
|
||||
LazyBase._recurse_apply(args, collect_replace)
|
||||
|
||||
shared_lazy = collect_replace.shared_lazy
|
||||
|
||||
return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
|
||||
else:
|
||||
del res # not needed
|
||||
# non-tensor return likely relies on the contents of the args
|
||||
# (e.g. the result of torch.equal)
|
||||
eager_args = cls.to_eager(args)
|
||||
return fn(*eager_args, **kwargs)
|
||||
return wrapped_fn
|
||||
|
||||
@classmethod
|
||||
def to_eager(cls, t: Any) -> Any:
|
||||
def simple_to_eager(_t: LazyBase) -> Any:
|
||||
def already_eager_to_eager(_t: LazyBase) -> Any:
|
||||
assert _t._data is not None
|
||||
return _t._data
|
||||
|
||||
while _t._data is None:
|
||||
lt = _t._lazy.popleft()
|
||||
if lt._data is not None:
|
||||
raise ValueError(f"{lt} did not belong in the lazy queue")
|
||||
assert lt._func is not None
|
||||
lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
|
||||
lt._data = lt._func(lt._args)
|
||||
# sanity check
|
||||
assert lt._data.dtype == lt._meta.dtype
|
||||
assert lt._data.shape == lt._meta.shape
|
||||
|
||||
return _t._data
|
||||
|
||||
# recurse into lists and/or tuples, keeping their structure
|
||||
return cls._recurse_apply(t, simple_to_eager)
|
||||
|
||||
@classmethod
|
||||
def eager_to_meta(cls, t: Any) -> Any:
|
||||
return cls.meta_with_dtype(t, t.dtype)
|
||||
|
||||
# must be overridden, meta tensor init is backend-specific
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass
|
||||
|
||||
@classmethod
|
||||
def from_eager(cls, t: Any) -> Any:
|
||||
if type(t) is cls:
|
||||
# already eager
|
||||
return t
|
||||
elif isinstance(t, cls._tensor_type):
|
||||
return cls(meta=cls.eager_to_meta(t), data=t)
|
||||
else:
|
||||
return TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}")
|
||||
|
||||
|
||||
class LazyNumpyTensor(LazyBase):
|
||||
_tensor_type = np.ndarray
|
||||
|
||||
@classmethod
|
||||
def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]:
|
||||
# The initial idea was to use np.nan as the fill value,
|
||||
# but non-float types like np.int16 can't use that.
|
||||
# So zero it is.
|
||||
cheat = np.zeros(1, dtype)
|
||||
return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape))
|
||||
|
||||
def astype(self, dtype, *args, **kwargs):
|
||||
meta = type(self).meta_with_dtype(self._meta, dtype)
|
||||
full_args = (self, dtype,) + args
|
||||
# very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
|
||||
return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
|
||||
|
||||
def tofile(self, *args, **kwargs):
|
||||
eager = LazyNumpyTensor.to_eager(self)
|
||||
return eager.tofile(*args, **kwargs)
|
||||
|
||||
# TODO: __array_function__
|
Loading…
Add table
Add a link
Reference in a new issue