From 5e27e7e11c5e5feb012a8ce1a25f1b78a0766b45 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Fri, 2 Aug 2024 16:14:49 -0400
Subject: [PATCH] convert_hf : simplify internal quantization type selection

---
 convert_hf_to_gguf.py     | 102 ++++++++++++++++----------
 gguf-py/gguf/constants.py |  51 ++++++++++++++++++-
 2 files changed, 94 insertions(+), 59 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 8b33c30d9..bfdf29a64 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -251,12 +251,7 @@ class Model:
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del name, new_name, bid, n_dims  # unused
-
-        return False
-
-    def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid, n_dims  # unused
 
         return False
@@ -285,55 +280,42 @@ class Model:
             for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
                 data: np.ndarray  # type hint
                 n_dims = len(data.shape)
-                data_dtype = data.dtype
-                data_qtype: gguf.GGMLQuantizationType | None = None
-
-                # when both are True, f32 should win
-                extra_f32 = self.extra_f32_tensors(name, new_name, bid, n_dims)
-                extra_f16 = self.extra_f16_tensors(name, new_name, bid, n_dims)
+                data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
 
                 # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors
-                # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
-                extra_f32 = any(cond for cond in (
-                    extra_f32,
-                    n_dims == 1,
-                    new_name.endswith("_norm.weight"),
-                ))
-
-                # Some tensor types are always in float32
-                extra_f32 = extra_f32 or any(self.match_model_tensor_name(new_name, key, bid) for key in (
-                    gguf.MODEL_TENSOR.FFN_GATE_INP,
-                    gguf.MODEL_TENSOR.POS_EMBD,
-                    gguf.MODEL_TENSOR.TOKEN_TYPES,
-                ))
-
-                # if f16 desired, convert any float32 2-dim weight tensors to float16
-                extra_f16 = any(cond for cond in (
-                    extra_f16,
-                    (name.endswith(".weight") and n_dims >= 2),
-                ))
-
-                if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
-                    if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
-                        data = gguf.quantize_bf16(data)
-                        assert data.dtype == np.uint16
-                        data_qtype = gguf.GGMLQuantizationType.BF16
-
-                    elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
-                        data = gguf.quantize_q8_0(data)
-                        assert data.dtype == np.uint8
-                        data_qtype = gguf.GGMLQuantizationType.Q8_0
-
-                    else:  # default to float16 for quantized tensors
-                        if data_dtype != np.float16:
-                            data = data.astype(np.float16)
-                        data_qtype = gguf.GGMLQuantizationType.F16
-
-                if data_qtype is None:  # by default, convert to float32
-                    if data_dtype != np.float32:
-                        data = data.astype(np.float32)
+                if n_dims <= 1 or new_name.endswith("_norm.weight"):
                     data_qtype = gguf.GGMLQuantizationType.F32
 
+                # Conditions should closely match those in llama_model_quantize_internal in llama.cpp
+                # Some tensor types are always in float32
+                if data_qtype is False and (
+                    any(
+                        self.match_model_tensor_name(new_name, key, bid)
+                        for key in (
+                            gguf.MODEL_TENSOR.FFN_GATE_INP,
+                            gguf.MODEL_TENSOR.POS_EMBD,
+                            gguf.MODEL_TENSOR.TOKEN_TYPES,
+                        )
+                    )
+                    or not name.endswith(".weight")
+                ):
+                    data_qtype = gguf.GGMLQuantizationType.F32
+
+                if isinstance(data_qtype, bool):
+                    data_qtype = gguf.LlamaFileTypeMap.get(self.ftype, gguf.GGMLQuantizationType.F32)
+
+                if data_qtype == gguf.GGMLQuantizationType.Q8_0:
+                    if gguf.quants.Q8_0.can_quantize(data):
+                        data = gguf.quants.Q8_0.quantize(data)
+                    else:  # fallback to f16
+                        data_qtype = gguf.GGMLQuantizationType.F16
+                if data_qtype == gguf.GGMLQuantizationType.BF16:
+                    data = gguf.quants.BF16.quantize(data)
+                if data_qtype == gguf.GGMLQuantizationType.F16:
+                    data = data.astype(np.float16, copy=False)
+                if data_qtype == gguf.GGMLQuantizationType.F32:
+                    data = data.astype(np.float32, copy=False)
+
                 shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
 
                 # reverse shape to make it similar to the internal ggml dimension order
@@ -1765,7 +1747,7 @@ class DbrxModel(Model):
 
         return [(new_name, data_torch)]
 
-    def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
+    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid  # unused
 
         return n_dims > 1
@@ -2680,18 +2662,22 @@ class MambaModel(Model):
 
         return [(new_name, data_torch)]
 
-    def extra_f32_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
-        del n_dims  # unused
-
-        return bid is not None and new_name in (
-            self.format_tensor_name(n, bid, ".weight" if name.endswith(".weight") else "") for n in [
+    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
+        if bid is not None and new_name in (
+            self.format_tensor_name(
+                n, bid, ".weight" if name.endswith(".weight") else ""
+            )
+            for n in [
                 gguf.MODEL_TENSOR.SSM_CONV1D,
                 gguf.MODEL_TENSOR.SSM_X,
                 gguf.MODEL_TENSOR.SSM_DT,
                 gguf.MODEL_TENSOR.SSM_A,
                 gguf.MODEL_TENSOR.SSM_D,
             ]
-        )
+        ):
+            return gguf.GGMLQuantizationType.F32
+
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
 
 
 @Model.register("CohereForCausalLM")
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index e343c2ef1..82febc4b6 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -1145,6 +1145,9 @@ class GGMLQuantizationType(IntEnum):
     F64     = 28
     IQ1_M   = 29
     BF16    = 30
+    Q4_0_4_4 = 31
+    Q4_0_4_8 = 32
+    Q4_0_8_8 = 33
 
 
 # TODO: add GGMLFileType from ggml_ftype in ggml.h
@@ -1157,7 +1160,7 @@ class LlamaFileType(IntEnum):
     MOSTLY_F16           = 1   # except 1d tensors
     MOSTLY_Q4_0          = 2   # except 1d tensors
     MOSTLY_Q4_1          = 3   # except 1d tensors
-    MOSTLY_Q4_1_SOME_F16 = 4   # tok_embeddings.weight and output.weight are F16
+    # MOSTLY_Q4_1_SOME_F16 = 4  # tok_embeddings.weight and output.weight are F16
     # MOSTLY_Q4_2        = 5   # support has been removed
     # MOSTLY_Q4_3        = 6   # support has been removed
     MOSTLY_Q8_0          = 7   # except 1d tensors
@@ -1186,10 +1189,53 @@ class LlamaFileType(IntEnum):
     MOSTLY_IQ4_XS        = 30  # except 1d tensors
     MOSTLY_IQ1_M         = 31  # except 1d tensors
     MOSTLY_BF16          = 32  # except 1d tensors
+    MOSTLY_Q4_0_4_4      = 33  # except 1d tensors
+    MOSTLY_Q4_0_4_8      = 34  # except 1d tensors
+    MOSTLY_Q4_0_8_8      = 35  # except 1d tensors
 
     GUESSED              = 1024  # not specified in the model file
 
 
+# Default quantization type for each file type
+# Keep this the same as in llama_model_quantize_internal from llama.cpp
+LlamaFileTypeMap: dict[LlamaFileType, GGMLQuantizationType] = {
+    LlamaFileType.MOSTLY_Q4_0:    GGMLQuantizationType.Q4_0,
+    LlamaFileType.MOSTLY_Q4_1:    GGMLQuantizationType.Q4_1,
+    LlamaFileType.MOSTLY_Q5_0:    GGMLQuantizationType.Q5_0,
+    LlamaFileType.MOSTLY_Q5_1:    GGMLQuantizationType.Q5_1,
+    LlamaFileType.MOSTLY_Q8_0:    GGMLQuantizationType.Q8_0,
+    LlamaFileType.MOSTLY_F16:     GGMLQuantizationType.F16,
+    LlamaFileType.MOSTLY_BF16:    GGMLQuantizationType.BF16,
+    LlamaFileType.ALL_F32:        GGMLQuantizationType.F32,
+
+    # K-quants
+    LlamaFileType.MOSTLY_Q2_K_S:  GGMLQuantizationType.Q2_K,
+    LlamaFileType.MOSTLY_Q2_K:    GGMLQuantizationType.Q2_K,
+    LlamaFileType.MOSTLY_IQ3_XS:  GGMLQuantizationType.IQ3_S,
+    LlamaFileType.MOSTLY_Q3_K_S:  GGMLQuantizationType.Q3_K,
+    LlamaFileType.MOSTLY_Q3_K_M:  GGMLQuantizationType.Q3_K,
+    LlamaFileType.MOSTLY_Q3_K_L:  GGMLQuantizationType.Q3_K,
+    LlamaFileType.MOSTLY_Q4_K_S:  GGMLQuantizationType.Q4_K,
+    LlamaFileType.MOSTLY_Q4_K_M:  GGMLQuantizationType.Q4_K,
+    LlamaFileType.MOSTLY_Q5_K_S:  GGMLQuantizationType.Q5_K,
+    LlamaFileType.MOSTLY_Q5_K_M:  GGMLQuantizationType.Q5_K,
+    LlamaFileType.MOSTLY_Q6_K:    GGMLQuantizationType.Q6_K,
+    LlamaFileType.MOSTLY_IQ2_XXS: GGMLQuantizationType.IQ2_XXS,
+    LlamaFileType.MOSTLY_IQ2_XS:  GGMLQuantizationType.IQ2_XS,
+    LlamaFileType.MOSTLY_IQ2_S:   GGMLQuantizationType.IQ2_XS,
+    LlamaFileType.MOSTLY_IQ2_M:   GGMLQuantizationType.IQ2_S,
+    LlamaFileType.MOSTLY_IQ3_XXS: GGMLQuantizationType.IQ3_XXS,
+    LlamaFileType.MOSTLY_IQ1_S:   GGMLQuantizationType.IQ1_S,
+    LlamaFileType.MOSTLY_IQ1_M:   GGMLQuantizationType.IQ1_M,
+    LlamaFileType.MOSTLY_IQ4_NL:  GGMLQuantizationType.IQ4_NL,
+    LlamaFileType.MOSTLY_IQ4_XS:  GGMLQuantizationType.IQ4_XS,
+    LlamaFileType.MOSTLY_IQ3_S:   GGMLQuantizationType.IQ3_S,
+    LlamaFileType.MOSTLY_IQ3_M:   GGMLQuantizationType.IQ3_S,
+    LlamaFileType.MOSTLY_Q4_0_4_4: GGMLQuantizationType.Q4_0_4_4,
+    LlamaFileType.MOSTLY_Q4_0_4_8: GGMLQuantizationType.Q4_0_4_8,
+    LlamaFileType.MOSTLY_Q4_0_8_8: GGMLQuantizationType.Q4_0_8_8,
+}
+
 class GGUFEndian(IntEnum):
     LITTLE = 0
     BIG = 1
@@ -1259,6 +1305,9 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
     GGMLQuantizationType.F64:     (1, 8),
     GGMLQuantizationType.IQ1_M:   (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
     GGMLQuantizationType.BF16:    (1, 2),
+    GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16),
+    GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16),
+    GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16),
 }
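
Note: the sketch below is not part of the patch. It is a minimal, hypothetical example (the class name "MyModel" and the registered architecture string are made up) of how a converter subclass is expected to use the new hook after this change, mirroring the DbrxModel and MambaModel overrides above. It assumes the usual context of convert_hf_to_gguf.py, where Model, gguf, and the tensor-name helpers are already in scope.

    # Hypothetical subclass, for illustration only (not in the patch).
    @Model.register("MyForCausalLM")
    class MyModel(Model):
        model_arch = gguf.MODEL_ARCH.LLAMA

        def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
            # Returning a GGMLQuantizationType pins the tensor to that exact type,
            # e.g. keep MoE router weights in F32 (the generic rule in write_tensors
            # would also do this when False is returned).
            if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_INP, bid):
                return gguf.GGMLQuantizationType.F32
            # Returning True lets 2D+ weights take the default type of the selected
            # file type (resolved through gguf.LlamaFileTypeMap in write_tensors).
            if name.endswith(".weight") and n_dims > 1:
                return True
            # Returning False (the base-class default) keeps the generic rules.
            return super().tensor_force_quant(name, new_name, bid, n_dims)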
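A quick illustration of the intended lookup in write_tensors, assuming the gguf package re-exports the constants module as the convert script's usage implies:

    # Illustration only: resolving the per-tensor default quantization type.
    import gguf

    qtype = gguf.LlamaFileTypeMap.get(gguf.LlamaFileType.MOSTLY_Q8_0, gguf.GGMLQuantizationType.F32)
    assert qtype == gguf.GGMLQuantizationType.Q8_0

    # File types without an entry in the map fall back to F32.
    qtype = gguf.LlamaFileTypeMap.get(gguf.LlamaFileType.GUESSED, gguf.GGMLQuantizationType.F32)
    assert qtype == gguf.GGMLQuantizationType.F32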