Use Q8_0 quantization from the gguf module.

This produces tensors exactly matching those in https://huggingface.co/Arki05/Grok-1-GGUF/tree/main/Q8_0
This commit is contained in:
Heiner 2024-05-23 11:10:59 +02:00
parent f177b6596c
commit e2f13a3346

View file

@ -124,7 +124,7 @@ def get_weights(fn):
def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
# equivalent to ggml_quantize_q8_0 in ggml.c
# equivalent to ggml_quantize_q8_0 in ggml.c (modulo rounding away from zero)
assert tensor.shape[1] % GGML_QK8_0 == 0
tensor = tensor.reshape(-1, GGML_QK8_0)
scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
@ -135,7 +135,7 @@ def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
# equivalent to ggml_quantize_q4_0 in ggml.c
# equivalent to ggml_quantize_q4_0 in ggml.c (modulo rounding away from zero)
assert tensor.shape[1] % GGML_QK4_0 == 0
tensor = tensor.reshape(-1, GGML_QK4_0)
abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
@ -150,7 +150,7 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
# equivalent to ggml_quantize_q4_1 in ggml.c
# equivalent to ggml_quantize_q4_1 in ggml.c (modulo rounding away from zero)
assert tensor.shape[1] % GGML_QK4_1 == 0
tensor = tensor.reshape(-1, GGML_QK4_1)
abs_max_indices = tensor.max(dim=-1, keepdim=True).indices
@ -170,13 +170,14 @@ def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
def maybe_quantize_tensor(tensor, ggml_type):
assert tensor.dtype == torch.float32
if ggml_type == gguf.GGMLQuantizationType.F32:
return tensor.float()
elif ggml_type == gguf.GGMLQuantizationType.F16:
return tensor.half()
elif ggml_type == gguf.GGMLQuantizationType.Q8_0:
return quantize_q8_0(tensor)
if tensor.device.type == "meta":
return quantize_q8_0(tensor) # Cannot convert into numpy array.
return torch.from_numpy(gguf.quantize_q8_0(tensor.numpy()))
elif ggml_type == gguf.GGMLQuantizationType.Q4_0:
return quantize_q4_0(tensor)
elif ggml_type == gguf.GGMLQuantizationType.Q4_1: