From 6a52bfe33292e23e5dc501c6b71f692ca45e277b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?=
Date: Mon, 10 Jun 2024 04:26:55 +0200
Subject: [PATCH] add truncate_bf16

---
 gguf-py/gguf/quants.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py
index b22eec166..31f689dd7 100644
--- a/gguf-py/gguf/quants.py
+++ b/gguf-py/gguf/quants.py
@@ -35,6 +35,12 @@ def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
     return n.astype(np.int16)
 
 
+# for fp32 values that are just extended bf16
+def __truncate_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
+    n = n.astype(np.float32, copy=False).view(np.uint32) >> 16
+    return n.astype(np.uint16)
+
+
 # This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
 def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
     rows = arr.reshape((-1, arr.shape[-1]))
@@ -62,6 +68,20 @@ def quantize_bf16(n: np.ndarray):
     return __quantize_bf16_array(n)
 
 
+def __truncate_bf16_array(n: np.ndarray) -> np.ndarray:
+    return __apply_over_grouped_rows(__truncate_fp32_to_bf16, arr=n, otype=np.uint16, oshape=n.shape)
+
+
+__truncate_bf16_lazy = LazyNumpyTensor._wrap_fn(__truncate_bf16_array, meta_noop=np.uint16)
+
+
+def truncate_bf16(n: np.ndarray):
+    if type(n) is LazyNumpyTensor:
+        return __truncate_bf16_lazy(n)
+    else:
+        return __truncate_bf16_array(n)
+
+
 __q8_block_size, __q8_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q8_0]
 
 
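Side note (not part of the patch): a minimal standalone sketch of the same bit
manipulation, showing why plain truncation is lossless for fp32 values that are
"just extended bf16", i.e. whose low 16 mantissa bits are already zero. The
helper name and test values here are illustrative only.

    import numpy as np

    def truncate_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
        # keep the high 16 bits: sign, 8 exponent bits, top 7 mantissa bits
        return (n.astype(np.float32, copy=False).view(np.uint32) >> 16).astype(np.uint16)

    # values exactly representable in bf16 (at most 8 significant mantissa bits)
    x = np.array([1.0, -2.5, 3.140625], dtype=np.float32)
    bf16 = truncate_fp32_to_bf16(x)

    # widen back by placing the bf16 bits in the high half of a uint32
    roundtrip = (bf16.astype(np.uint32) << 16).view(np.float32)
    assert np.array_equal(x, roundtrip)  # truncation lost nothing

Unlike the conversion in __compute_fp32_to_bf16 above, this path drops the low
16 bits without any rounding, so it is only exact when the input is known to be
bf16 data that was widened to fp32, which is the case the patch comment names.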