missing cast and additional numpy 2.x fix
This commit is contained in:
parent
225ec48fe5
commit
e8e2b7e03f
1 changed files with 2 additions and 2 deletions
|
@ -27,11 +27,11 @@ def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizati
|
||||||
def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
|
def __compute_fp32_to_bf16(n: np.ndarray) -> np.ndarray:
|
||||||
n = n.astype(np.float32, copy=False).view(np.uint32)
|
n = n.astype(np.float32, copy=False).view(np.uint32)
|
||||||
# force nan to quiet
|
# force nan to quiet
|
||||||
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | (64 << 16), n)
|
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n)
|
||||||
# flush subnormals to zero
|
# flush subnormals to zero
|
||||||
n = np.where((n & 0x7f800000) == 0, n & np.uint32(0x80000000), n)
|
n = np.where((n & 0x7f800000) == 0, n & np.uint32(0x80000000), n)
|
||||||
# round to nearest even
|
# round to nearest even
|
||||||
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
|
n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16
|
||||||
return n.astype(np.uint16)
|
return n.astype(np.uint16)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue