add truncate_bf16
parent 10ceba354a
commit 6a52bfe332
1 changed file with 20 additions and 0 deletions
```diff
@@ -35,6 +35,12 @@ def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
     return n.astype(np.int16)
 
 
+# for fp32 values that are just extended bf16
+def __truncate_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
+    n = n.astype(np.float32, copy=False).view(np.uint32) >> 16
+    return n.astype(np.uint16)
+
+
 # This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time
 def __apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray:
     rows = arr.reshape((-1, arr.shape[-1]))
```
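The truncation works because bf16 is exactly the top 16 bits of an IEEE-754 fp32 bit pattern: for values whose low 16 bits are already zero (the "fp32 values that are just extended bf16" the comment refers to), dropping those bits loses nothing, so no rounding is needed as in the __compute_fp32_to_bf16 path above. A minimal standalone sketch of the round trip; the helper names here are illustrative, not part of the diff:

```python
import numpy as np

def truncate_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
    # keep only the high 16 bits of each fp32 bit pattern
    return (n.astype(np.float32, copy=False).view(np.uint32) >> 16).astype(np.uint16)

def bf16_to_fp32(b: np.ndarray) -> np.ndarray:
    # widen back by placing the bf16 bits in the high half of a uint32
    return (b.astype(np.uint32) << 16).view(np.float32)

x = np.array([1.0, -2.5, 0.15625], dtype=np.float32)  # all exactly representable in bf16
assert np.array_equal(bf16_to_fp32(truncate_fp32_to_bf16(x)), x)
```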
```diff
@@ -62,6 +68,20 @@ def quantize_bf16(n: np.ndarray):
         return __quantize_bf16_array(n)
 
 
+def __truncate_bf16_array(n: np.ndarray) -> np.ndarray:
+    return __apply_over_grouped_rows(__truncate_fp32_to_bf16, arr=n, otype=np.uint16, oshape=n.shape)
+
+
+__truncate_bf16_lazy = LazyNumpyTensor._wrap_fn(__truncate_bf16_array, meta_noop=np.uint16)
+
+
+def truncate_bf16(n: np.ndarray):
+    if type(n) is LazyNumpyTensor:
+        return __truncate_bf16_lazy(n)
+    else:
+        return __truncate_bf16_array(n)
+
+
 __q8_block_size, __q8_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q8_0]
 
 
```
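truncate_bf16 mirrors the dispatch of quantize_bf16 directly above it: plain NumPy arrays take the eager __truncate_bf16_array path, while LazyNumpyTensor inputs go through the _wrap_fn-wrapped lazy variant, with meta_noop=np.uint16 declaring the output dtype for the lazy metadata. A hedged usage sketch; the import path is an assumption about the surrounding gguf-py package layout, not something this diff shows:

```python
import numpy as np

from gguf.quants import truncate_bf16  # import path assumed, not shown in the diff

# fp32 values that are already "extended bf16" truncate losslessly
data = np.array([0.5, 1.0, -3.0, 100.0], dtype=np.float32)
packed = truncate_bf16(data)        # uint16 array of raw bf16 bit patterns
print(packed.dtype, packed.shape)   # uint16 (4,)
```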