gguf-py : fix and simplify quantized shape round-trip

This commit is contained in:
Francis Couture-Harpin 2024-05-22 23:40:41 -04:00
parent cd93a28cb1
commit 2ff601fc32
5 changed files with 27 additions and 13 deletions

View file

@@ -26,6 +26,8 @@ from .constants import (
TokenType,
)
from .quants import quant_shape_from_byte_shape
logger = logging.getLogger(__name__)
@@ -229,10 +231,7 @@ class GGUFWriter:
else:
dtype = raw_dtype
if tensor_dtype == np.uint8:
-block_size, type_size = GGML_QUANT_SIZES[raw_dtype]
-if tensor_shape[-1] % type_size != 0:
-raise ValueError(f"Quantized tensor row size ({tensor_shape[-1]}) is not a multiple of {dtype.name} type size ({type_size})")
-tensor_shape = tuple(tensor_shape[:-1]) + (tensor_shape[-1] // type_size * block_size,)
+tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
n_dims = len(tensor_shape)
self.ti_data += self._pack("I", n_dims)
for i in range(n_dims):