Compare commits

...
Sign in to create a new pull request.

5 commits

Author SHA1 Message Date
Francis Couture-Harpin
229c35cb59 gguf-py : remove LlamaFileTypeMap
Too specific to 'llama.cpp', and would be a maintenance burden
to keep up to date.

* gguf-py : add generic quantize and dequantize functions

The quant classes no longer need to be known,
only the target or the source type,
for 'quantize' and 'dequantize', respectively.
2024-08-03 21:22:37 -04:00
Francis Couture-Harpin
e82ff5a346 gguf-py : fix BF16 numpy view type 2024-08-02 17:42:46 -04:00
Francis Couture-Harpin
861265b91e gguf-py : fix flake8 lint 2024-08-02 16:23:30 -04:00
Francis Couture-Harpin
5e27e7e11c convert_hf : simplify internal quantization type selection 2024-08-02 16:14:49 -04:00
Francis Couture-Harpin
1ac1a79161 gguf-py : use classes for quants 2024-08-02 15:18:16 -04:00
4 changed files with 226 additions and 132 deletions

View file

@ -251,12 +251,7 @@ class Model:
return [(self.map_tensor_name(name), data_torch)] return [(self.map_tensor_name(name), data_torch)]
def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
del name, new_name, bid, n_dims # unused
return False
def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
del name, new_name, bid, n_dims # unused del name, new_name, bid, n_dims # unused
return False return False
@ -285,55 +280,47 @@ class Model:
for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)): for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
data: np.ndarray # type hint data: np.ndarray # type hint
n_dims = len(data.shape) n_dims = len(data.shape)
data_dtype = data.dtype data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
data_qtype: gguf.GGMLQuantizationType | None = None
# when both are True, f32 should win
extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
# Most of the codebase that takes in 1D tensors or norms only handles F32 tensors # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
# Conditions should closely match those in llama_model_quantize_internal in llama.cpp if n_dims <= 1 or new_name.endswith("_norm.weight"):
extra_f32 = any(cond for cond in ( data_qtype = gguf.GGMLQuantizationType.F32
extra_f32,
n_dims == 1,
new_name.endswith("_norm.weight"),
))
# Conditions should closely match those in llama_model_quantize_internal in llama.cpp
# Some tensor types are always in float32 # Some tensor types are always in float32
extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in ( if data_qtype is False and (
any(
self.match_model_tensor_name(new_name, key, bid)
for key in (
gguf.MODEL_TENSOR.FFN_GATE_INP, gguf.MODEL_TENSOR.FFN_GATE_INP,
gguf.MODEL_TENSOR.POS_EMBD, gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES, gguf.MODEL_TENSOR.TOKEN_TYPES,
)) )
)
# if f16 desired, convert any float32 2-dim weight tensors to float16 or not name.endswith(".weight")
extra_f16 = any(cond for cond in ( ):
extra_f16,
(name.endswith(".weight") and n_dims >= 2),
))
if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
data = gguf.quantize_bf16(data)
assert data.dtype == np.uint16
data_qtype = gguf.GGMLQuantizationType.BF16
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
data = gguf.quantize_q8_0(data)
assert data.dtype == np.uint8
data_qtype = gguf.GGMLQuantizationType.Q8_0
else: # default to float16 for quantized tensors
if data_dtype != np.float16:
data = data.astype(np.float16)
data_qtype = gguf.GGMLQuantizationType.F16
if data_qtype is None: # by default, convert to float32
if data_dtype != np.float32:
data = data.astype(np.float32)
data_qtype = gguf.GGMLQuantizationType.F32 data_qtype = gguf.GGMLQuantizationType.F32
# No override (data_qtype is False), or wants to be quantized (data_qtype is True)
if isinstance(data_qtype, bool):
if self.ftype == gguf.LlamaFileType.ALL_F32:
data_qtype = gguf.GGMLQuantizationType.F32
elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
data_qtype = gguf.GGMLQuantizationType.F16
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
data_qtype = gguf.GGMLQuantizationType.BF16
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
data_qtype = gguf.GGMLQuantizationType.Q8_0
else:
raise ValueError(f"Unknown file type: {self.ftype.name}")
try:
data = gguf.quants.quantize(data, data_qtype)
except gguf.QuantError as e:
logger.warning("%s, %s", e, "falling back to F16")
data_qtype = gguf.GGMLQuantizationType.F16
data = gguf.quants.quantize(data, data_qtype)
shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
# reverse shape to make it similar to the internal ggml dimension order # reverse shape to make it similar to the internal ggml dimension order
@ -1765,7 +1752,7 @@ class DbrxModel(Model):
return [(new_name, data_torch)] return [(new_name, data_torch)]
def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
del name, new_name, bid # unused del name, new_name, bid # unused
return n_dims > 1 return n_dims > 1
@ -2680,18 +2667,22 @@ class MambaModel(Model):
return [(new_name, data_torch)] return [(new_name, data_torch)]
def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool: def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
del n_dims # unused if bid is not None and new_name in (
self.format_tensor_name(
return bid is not None and new_name in ( n, bid, ".weight" if name.endswith(".weight") else ""
self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [ )
for n in [
gguf.MODEL_TENSOR.SSM_CONV1D, gguf.MODEL_TENSOR.SSM_CONV1D,
gguf.MODEL_TENSOR.SSM_X, gguf.MODEL_TENSOR.SSM_X,
gguf.MODEL_TENSOR.SSM_DT, gguf.MODEL_TENSOR.SSM_DT,
gguf.MODEL_TENSOR.SSM_A, gguf.MODEL_TENSOR.SSM_A,
gguf.MODEL_TENSOR.SSM_D, gguf.MODEL_TENSOR.SSM_D,
] ]
) ):
return gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)
@Model.register("CohereForCausalLM") @Model.register("CohereForCausalLM")

View file

@ -1145,6 +1145,9 @@ class GGMLQuantizationType(IntEnum):
F64 = 28 F64 = 28
IQ1_M = 29 IQ1_M = 29
BF16 = 30 BF16 = 30
Q4_0_4_4 = 31
Q4_0_4_8 = 32
Q4_0_8_8 = 33
# TODO: add GGMLFileType from ggml_ftype in ggml.h # TODO: add GGMLFileType from ggml_ftype in ggml.h
@ -1157,7 +1160,7 @@ class LlamaFileType(IntEnum):
MOSTLY_F16 = 1 # except 1d tensors MOSTLY_F16 = 1 # except 1d tensors
MOSTLY_Q4_0 = 2 # except 1d tensors MOSTLY_Q4_0 = 2 # except 1d tensors
MOSTLY_Q4_1 = 3 # except 1d tensors MOSTLY_Q4_1 = 3 # except 1d tensors
MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16 # MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
# MOSTLY_Q4_2 = 5 # support has been removed # MOSTLY_Q4_2 = 5 # support has been removed
# MOSTLY_Q4_3 = 6 # support has been removed # MOSTLY_Q4_3 = 6 # support has been removed
MOSTLY_Q8_0 = 7 # except 1d tensors MOSTLY_Q8_0 = 7 # except 1d tensors
@ -1186,6 +1189,9 @@ class LlamaFileType(IntEnum):
MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ4_XS = 30 # except 1d tensors
MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors
MOSTLY_BF16 = 32 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors
MOSTLY_Q4_0_4_4 = 33 # except 1d tensors
MOSTLY_Q4_0_4_8 = 34 # except 1d tensors
MOSTLY_Q4_0_8_8 = 35 # except 1d tensors
GUESSED = 1024 # not specified in the model file GUESSED = 1024 # not specified in the model file
@ -1259,6 +1265,9 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.F64: (1, 8), GGMLQuantizationType.F64: (1, 8),
GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
GGMLQuantizationType.BF16: (1, 2), GGMLQuantizationType.BF16: (1, 2),
GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16),
GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16),
GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16),
} }

View file

@ -191,6 +191,8 @@ class LazyBase(ABC, metaclass=LazyMeta):
class LazyNumpyTensor(LazyBase): class LazyNumpyTensor(LazyBase):
_tensor_type = np.ndarray _tensor_type = np.ndarray
shape: tuple[int, ...] # Makes the type checker happy in quants.py
@classmethod @classmethod
def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]: def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
# The initial idea was to use np.nan as the fill value, # The initial idea was to use np.nan as the fill value,

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Callable, Sequence from abc import ABC, abstractmethod
from typing import Any, Callable, Sequence
from numpy.typing import DTypeLike from numpy.typing import DTypeLike
@ -9,32 +10,22 @@ from .lazy import LazyNumpyTensor
import numpy as np import numpy as np
def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType): def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
block_size, type_size = GGML_QUANT_SIZES[quant_type] block_size, type_size = GGML_QUANT_SIZES[quant_type]
if shape[-1] % block_size != 0: if shape[-1] % block_size != 0:
raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})") raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
return (*shape[:-1], shape[-1] // block_size * type_size) return (*shape[:-1], shape[-1] // block_size * type_size)
def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType): def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]:
block_size, type_size = GGML_QUANT_SIZES[quant_type] block_size, type_size = GGML_QUANT_SIZES[quant_type]
if shape[-1] % type_size != 0: if shape[-1] % type_size != 0:
raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})") raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})")
return (*shape[:-1], shape[-1] // type_size * block_size) return (*shape[:-1], shape[-1] // type_size * block_size)
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
n = n.astype(np.float32, copy=False).view(np.uint32)
# force nan to quiet
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
# round to nearest even
n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
return n.astype(np.uint16)
# This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time # This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray: def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
rows = arr.reshape((-1, arr.shape[-1])) rows = arr.reshape((-1, arr.shape[-1]))
osize = 1 osize = 1
for dim in oshape: for dim in oshape:
@ -46,27 +37,6 @@ def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.
return out.reshape(oshape) return out.reshape(oshape)
def __quantize_bf16_array(n: np.ndarray) -> np.ndarray:
return __apply_over_grouped_rows(__compute_fp32_to_bf16, arr=n, otype=np.uint16, oshape=n.shape)
__quantize_bf16_lazy = LazyNumpyTensor._wrap_fn(__quantize_bf16_array, meta_noop=np.uint16)
def quantize_bf16(n: np.ndarray):
if type(n) is LazyNumpyTensor:
return __quantize_bf16_lazy(n)
else:
return __quantize_bf16_array(n)
__q8_block_size, __q8_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q8_0]
def can_quantize_to_q8_0(n: np.ndarray) -> bool:
return n.shape[-1] % __q8_block_size == 0
# round away from zero # round away from zero
# ref: https://stackoverflow.com/a/59143326/22827863 # ref: https://stackoverflow.com/a/59143326/22827863
def np_roundf(n: np.ndarray) -> np.ndarray: def np_roundf(n: np.ndarray) -> np.ndarray:
@ -76,18 +46,151 @@ def np_roundf(n: np.ndarray) -> np.ndarray:
return np.sign(n) * b return np.sign(n) * b
def __quantize_q8_0_shape_change(s: tuple[int, ...]) -> tuple[int, ...]: class QuantError(Exception): ...
return (*s[:-1], s[-1] // __q8_block_size * __q8_type_size)
# Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c _type_traits: dict[GGMLQuantizationType, type[__Quant]] = {}
def __quantize_q8_0_rows(n: np.ndarray) -> np.ndarray:
shape = n.shape
assert shape[-1] % __q8_block_size == 0
n_blocks = n.size // __q8_block_size
blocks = n.reshape((n_blocks, __q8_block_size)).astype(np.float32, copy=False) def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
if qtype == GGMLQuantizationType.F32:
return data.astype(np.float32, copy=False)
elif qtype == GGMLQuantizationType.F16:
return data.astype(np.float16, copy=False)
elif (q := _type_traits.get(qtype)) is not None:
return q.quantize(data)
else:
raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented")
def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
if qtype == GGMLQuantizationType.F32 or qtype == GGMLQuantizationType.F16:
return data.astype(np.float32, copy=False)
elif (q := _type_traits.get(qtype)) is not None:
return q.dequantize(data)
else:
raise NotImplementedError(f"Dequantization for {qtype.name} is not yet implemented")
class __Quant(ABC):
qtype: GGMLQuantizationType
block_size: int
type_size: int
def __init__(self):
return TypeError("Quant conversion classes can't have instances")
def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None:
cls.qtype = qtype
cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype]
cls.__quantize_lazy = LazyNumpyTensor._wrap_fn(
cls.__quantize_array,
meta_noop=(np.uint8, cls.__shape_to_bytes)
)
cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn(
cls.__dequantize_array,
meta_noop=(np.float32, cls.__shape_from_bytes)
)
assert qtype not in _type_traits
_type_traits[qtype] = cls
@classmethod
@abstractmethod
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
raise NotImplementedError
@classmethod
@abstractmethod
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
raise NotImplementedError
@classmethod
def quantize_rows(cls, rows: np.ndarray) -> np.ndarray:
rows = rows.astype(np.float32, copy=False)
shape = rows.shape
n_blocks = rows.size // cls.block_size
blocks = rows.reshape((n_blocks, cls.block_size))
blocks = cls.quantize_blocks(blocks)
assert blocks.dtype == np.uint8
assert blocks.shape[-1] == cls.type_size
return blocks.reshape(cls.__shape_to_bytes(shape))
@classmethod
def dequantize_rows(cls, rows: np.ndarray) -> np.ndarray:
rows = rows.view(np.uint8)
shape = rows.shape
n_blocks = rows.size // cls.type_size
blocks = rows.reshape((n_blocks, cls.type_size))
blocks = cls.dequantize_blocks(blocks)
assert blocks.dtype == np.float32
assert blocks.shape[-1] == cls.block_size
return blocks.reshape(cls.__shape_from_bytes(shape))
@classmethod
def __shape_to_bytes(cls, shape: Sequence[int]):
return quant_shape_to_byte_shape(shape, cls.qtype)
@classmethod
def __shape_from_bytes(cls, shape: Sequence[int]):
return quant_shape_from_byte_shape(shape, cls.qtype)
@classmethod
def __quantize_array(cls, array: np.ndarray) -> np.ndarray:
return _apply_over_grouped_rows(cls.quantize_rows, arr=array, otype=np.uint8, oshape=cls.__shape_to_bytes(array.shape))
@classmethod
def __dequantize_array(cls, array: np.ndarray) -> np.ndarray:
return _apply_over_grouped_rows(cls.dequantize_rows, arr=array, otype=np.float32, oshape=cls.__shape_from_bytes(array.shape))
@classmethod
def __quantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
pass
@classmethod
def __dequantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any:
pass
@classmethod
def can_quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> bool:
return tensor.shape[-1] % cls.block_size == 0
@classmethod
def quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
if not cls.can_quantize(tensor):
raise QuantError(f"Can't quantize tensor with shape {tensor.shape} to {cls.qtype.name}")
if isinstance(tensor, LazyNumpyTensor):
return cls.__quantize_lazy(tensor)
else:
return cls.__quantize_array(tensor)
@classmethod
def dequantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray:
if isinstance(tensor, LazyNumpyTensor):
return cls.__dequantize_lazy(tensor)
else:
return cls.__dequantize_array(tensor)
class BF16(__Quant, qtype=GGMLQuantizationType.BF16):
@classmethod
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
n = blocks.view(np.uint32)
# force nan to quiet
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
# round to nearest even
n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
return n.astype(np.uint16).view(np.uint8)
@classmethod
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
return (blocks.view(np.int16).astype(np.int32) << 16).view(np.float32)
class Q8_0(__Quant, qtype=GGMLQuantizationType.Q8_0):
@classmethod
# Implementation of Q8_0 with bit-exact same results as reference implementation in ggml-quants.c
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
d = abs(blocks).max(axis=1, keepdims=True) / 127 d = abs(blocks).max(axis=1, keepdims=True) / 127
with np.errstate(divide="ignore"): with np.errstate(divide="ignore"):
@ -99,23 +202,12 @@ def __quantize_q8_0_rows(n: np.ndarray) -> np.ndarray:
# (n_blocks, block_size) # (n_blocks, block_size)
qs = qs.astype(np.int8).view(np.uint8) qs = qs.astype(np.int8).view(np.uint8)
assert d.shape[1] + qs.shape[1] == __q8_type_size return np.concatenate([d, qs], axis=1)
return np.concatenate([d, qs], axis=1).reshape(__quantize_q8_0_shape_change(shape)) @classmethod
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
d, x = np.split(blocks, [2], axis=1)
d = d.view(np.float16).astype(np.float32)
x = x.view(np.int8).astype(np.float32)
return (x * d)
def __quantize_q8_0_array(n: np.ndarray) -> np.ndarray:
return __apply_over_grouped_rows(__quantize_q8_0_rows, arr=n, otype=np.uint8, oshape=__quantize_q8_0_shape_change(n.shape))
__quantize_q8_0_lazy = LazyNumpyTensor._wrap_fn(
__quantize_q8_0_array,
meta_noop=(np.uint8, __quantize_q8_0_shape_change),
)
def quantize_q8_0(data: np.ndarray):
if type(data) is LazyNumpyTensor:
return __quantize_q8_0_lazy(data)
else:
return __quantize_q8_0_array(data)