gguf-py : Numpy dequantization for grid-based i-quants

This commit is contained in:
Francis Couture-Harpin 2024-08-09 23:47:31 -04:00
parent 5a9edda7ca
commit 73bc9350cd
2 changed files with 638 additions and 18 deletions

View file

@ -60,9 +60,13 @@ class GGMLQuants:
ctypes.POINTER(ctypes.c_float),
)
self.libggml.ggml_quantize_requires_imatrix.restype = ctypes.c_bool
self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,)
for t in (
"q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
"q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
"iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
"iq4_nl", "iq4_xs",
):
dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t)
@ -97,7 +101,12 @@ class GGMLQuants:
def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C")
result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], ctypes.cast(0, c_float_p))
if self.libggml.ggml_quantize_requires_imatrix(qtype.value):
# TODO: is a column-wise sum of squares appropriate?
qw = np.sum((data * data).reshape((-1, data.shape[-1])), axis=0).ctypes.data_as(c_float_p)
else:
qw = ctypes.cast(0, c_float_p)
result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw)
assert result.size == result_size
return result
@ -116,8 +125,10 @@ def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType)
t2 = t2.reshape((-1, type_size))
x = t1.view(np.uint8) ^ t2.view(np.uint8)
diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1)
logger.debug(f"{diff_bits.shape=}")
num_bad_blocks = np.count_nonzero(diff_bits, axis=0)
if num_bad_blocks == 0 and t1.shape == t2.shape:
logger.debug("Bits are equal, but arrays don't match, likely contains NANs")
return True
logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)")
bad_block_id = np.argmax(diff_bits, axis=0)
logger.debug(f"Worst block id: {bad_block_id}")
@ -128,7 +139,7 @@ def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType)
return False
def do_test(libggml_path: Path):
def do_test(libggml_path: Path, quick: bool = False):
ggml_quants = GGMLQuants(libggml_path)
np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})
@ -162,16 +173,15 @@ def do_test(libggml_path: Path):
rc = r.copy(order="C")
pyq = None
ggq = None
if has_quantize:
logger.debug(f"Quantizing to {qtype.name} with Python")
pyq = gguf.quants.quantize(rc, qtype)
logger.debug(f"Quantizing to {qtype.name} with C")
ggq = ggml_quants.quantize(rc, qtype)
logger.debug(f"Quantizing to {qtype.name} with C")
ggq = ggml_quants.quantize(rc, qtype)
if has_quantize:
assert pyq is not None
if qtype == GGMLQuantizationType.F16:
pyq = pyq.view(np.uint8)
quant_equal = compare_tensors(pyq, ggq, qtype)
@ -182,25 +192,46 @@ def do_test(libggml_path: Path):
logger.info(f"Quantization to {qtype.name} matches exactly ✅")
if has_dequantize:
logger.debug(f"Dequantizing from {qtype.name} with Python")
pydq = gguf.quants.dequantize(ggq, qtype)
logger.debug(f"Dequantizing from {qtype.name} with C")
ggdq = ggml_quants.dequantize(ggq, qtype)
if ggq is None and not quick:
logger.debug(f"Quantizing to {qtype.name} with C")
ggq = ggml_quants.quantize(rc, qtype)
if ggq is not None:
logger.debug(f"Dequantizing from {qtype.name} with Python")
pydq = gguf.quants.dequantize(ggq, qtype)
logger.debug(f"Dequantizing from {qtype.name} with C")
ggdq = ggml_quants.dequantize(ggq, qtype)
dequant_equal = compare_tensors(pydq, ggdq, qtype)
if not dequant_equal:
logger.error(f"Dequantization from {qtype.name} does not match ❌")
else:
logger.info(f"Dequantization from {qtype.name} matches exactly ✅")
rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype)
rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8)
logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python")
pydq = gguf.quants.dequantize(rq, qtype)
logger.debug(f"Dequantizing random f16 data as {qtype.name} with C")
ggdq = ggml_quants.dequantize(rq, qtype)
dequant_equal = compare_tensors(pydq, ggdq, qtype)
if not dequant_equal:
logger.error(f"Dequantization from {qtype.name} does not match ❌")
logger.error(f"Dequantization from random f16 data as {qtype.name} does not match ❌")
else:
logger.info(f"Dequantization from {qtype.name} matches exactly ✅")
logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly ✅")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so")
parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")
args = parser.parse_args()
logging.basicConfig(level=logging.DEBUG)
do_test(args.libggml)
do_test(args.libggml, args.quick)