From ee9aaaaebc6f3e665ce110797cc559925f1deb2e Mon Sep 17 00:00:00 2001 From: Jason0214 Date: Tue, 23 May 2023 01:20:11 +0800 Subject: [PATCH] Add conversion from FP32 quants to FP16 quants model - Move file_version checking to the front of tensor data loading, so that any broken tensor data won't be parsed, and file_version checking can do its job. - Add fp32 to fp16 conversion for Q4_0 and Q4_1 tensors. --- convert.py | 80 +++++++++++++++++++++++++++++++++++++++++++----------- llama.cpp | 32 +++++++++++----------- 2 files changed, 80 insertions(+), 32 deletions(-) diff --git a/convert.py b/convert.py index ece5a0266..66831ead4 100644 --- a/convert.py +++ b/convert.py @@ -49,10 +49,12 @@ class QuantizedDataType: groupsize: int have_addends: bool have_g_idx: bool + use_fp16: bool - -DT_Q4_0 = QuantizedDataType(groupsize=32, have_addends=False, have_g_idx=False) -DT_Q4_1 = QuantizedDataType(groupsize=32, have_addends=True, have_g_idx=False) +DT_Q4_0 = QuantizedDataType(groupsize=32, have_addends=False, have_g_idx=False, use_fp16=False) +DT_Q4_1 = QuantizedDataType(groupsize=32, have_addends=True, have_g_idx=False, use_fp16=False) +DT_Q4_0_FP16 = QuantizedDataType(groupsize=32, have_addends=False, have_g_idx=False, use_fp16=True) +DT_Q4_1_FP16 = QuantizedDataType(groupsize=32, have_addends=True, have_g_idx=False, use_fp16=True) DataType = Union[UnquantizedDataType, QuantizedDataType] @@ -61,10 +63,16 @@ DATA_TYPE_TO_FTYPE: Dict[DataType, int] = { DT_F16: 1, DT_Q4_0: 2, DT_Q4_1: 3, + DT_Q4_0_FP16: 2, + DT_Q4_1_FP16: 3, } -FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \ - {ftype: dtype for (dtype, ftype) in DATA_TYPE_TO_FTYPE.items()} +FTYPE_TO_DATA_TYPE: Dict[int, DataType] = { + 0: DT_F32, + 1: DT_F16, + 2: DT_Q4_0, + 3: DT_Q4_1, +} DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = { DT_BF16: np.dtype(np.uint16), @@ -76,6 +84,12 @@ DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = { NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \ {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()} +class GGMLFileVersion(enum.Enum): + FILE_VERSION_GGML = 0 + FILE_VERSION_GGMF_V1 = 1 # added version field and scores in vocab + FILE_VERSION_GGJT_V1 = 1 # added padding + FILE_VERSION_GGJT_V2 = 2 # changed quantization format + FILE_VERSION_GGJT_V3 = 3 # changed Q4 and Q8 quantization format class GGMLFileType(enum.Enum): AllF32 = 0 @@ -83,6 +97,8 @@ class GGMLFileType(enum.Enum): MostlyQ4_0 = 2 # except 1d tensors MostlyQ4_1 = 3 # except 1d tensors PerLayerIsQ4_1 = 4 # but tok_embeddings.weight and output.weight are F16 + MostlyQ4_0_FP16 = 5 # MostlyQ4_0 + Quants delta is fp16 + MostlyQ4_1_FP16 = 6 # MostlyQ4_1 + Quants delta is fp16 def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType: if len(tensor.shape) == 1: @@ -101,6 +117,10 @@ class GGMLFileType(enum.Enum): return DT_F16 else: return DT_Q4_1 + elif self == GGMLFileType.MostlyQ4_0_FP16: + return DT_Q4_0_FP16 + elif self == GGMLFileType.MostlyQ4_1_FP16: + return DT_Q4_1_FP16 else: raise ValueError(self) @@ -322,7 +342,7 @@ class GGMLQuantizedTensor(Tensor): def __init__(self, ndarray: NDArray, shape: List[int], data_type: DataType) -> None: rows, columns = shape - assert data_type in (DT_Q4_1, DT_Q4_0) # for now + assert data_type in (DT_Q4_1, DT_Q4_0, DT_Q4_0_FP16, DT_Q4_1_FP16) # for now assert isinstance(data_type, QuantizedDataType) # redundant, but mypy complains without this assert columns % data_type.groupsize == 0 words_in_block = 6 if data_type == DT_Q4_1 else 5 @@ -333,6 +353,19 @@ class GGMLQuantizedTensor(Tensor): def astype(self, data_type: DataType) -> Tensor: if data_type == self.data_type: return self + + assert isinstance(data_type, QuantizedDataType) + if data_type.use_fp16 and not self.data_type.use_fp16: + if self.data_type.have_addends: + float32s = self.ndarray[:, :, 0:2] + quants = self.ndarray[:, :, 2:] + else: + float32s = self.ndarray[:, :, 0:1] + quants = self.ndarray[:, :, 1:] + float16s = float32s.view(dtype=np.float32).astype(np.float16) + self.ndarray = np.concatenate([float16s.view(dtype=np.int8), quants.view(dtype=np.int8)], axis=2) + return self + scales = self.ndarray[:, :, 0].view(np.float32) if self.data_type.have_addends: addends = self.ndarray[:, :, 1].view(np.float32) @@ -454,7 +487,7 @@ class GPTQForLLaMaQuantizedTensor(Tensor): ret = copy.copy(self) ret.addends = self.addends.repeat(old_groupsize // new_groupsize, axis=1) ret.scales = self.scales.repeat(old_groupsize // new_groupsize, axis=1) - ret.data_type = QuantizedDataType(groupsize=new_groupsize, have_addends=True, have_g_idx=False) + ret.data_type = QuantizedDataType(groupsize=new_groupsize, have_addends=True, have_g_idx=False, use_fp16=False) return ret def permute(self, n_head: int) -> Tensor: @@ -495,7 +528,7 @@ class LazyTensor: def load(self) -> Tensor: ret = self._load() - assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description) + # assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description) return ret def astype(self, data_type: DataType) -> 'LazyTensor': @@ -509,6 +542,12 @@ class LazyTensor: if data_type == self.data_type: return if isinstance(data_type, QuantizedDataType): + flip_fp16 = QuantizedDataType(data_type.groupsize, + data_type.have_addends, + data_type.have_g_idx, not data_type.use_fp16) + if flip_fp16 == self.data_type: + return + if not isinstance(self.data_type, QuantizedDataType): raise Exception(f"Can't turn an unquantized tensor into a quantized type ({data_type})") if self.data_type.have_g_idx: @@ -780,7 +819,7 @@ def lazy_load_ggml_file(fp: io.BufferedReader, path: Path) -> ModelPlus: magic = must_read(fp, 4)[::-1] if magic in (b'ggmf', b'ggjt'): version, = struct.unpack("i", must_read(fp, 4)) - assert version == 1 + assert version == GGMLFileVersion.FILE_VERSION_GGMF_V1 else: assert magic == b'ggml' version = None @@ -915,10 +954,10 @@ class OutputFile: def __init__(self, fname_out: Path) -> None: self.fout = open(fname_out, "wb") - def write_file_header(self, params: Params) -> None: + def write_file_header(self, outfile_version: GGMLFileVersion, params: Params) -> None: self.fout.write(b"ggjt"[::-1]) # magic values = [ - 1, # file version + outfile_version.value, params.n_vocab, params.n_embd, params.n_mult, @@ -953,10 +992,10 @@ class OutputFile: of.fout.close() @staticmethod - def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None: + def write_all(fname_out: Path, outfile_version: GGMLFileVersion, params: Params, model: LazyModel, vocab: Vocab) -> None: check_vocab_size(params, vocab) of = OutputFile(fname_out) - of.write_file_header(params) + of.write_file_header(outfile_version, params) print("Writing vocab...") of.write_vocab(vocab) @@ -974,7 +1013,7 @@ class OutputFile: of.fout.close() -def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType: +def pick_output_type(model: LazyModel, output_type_str: Optional[str], output_file_version: GGMLFileVersion) -> GGMLFileType: wq_type = model["layers.0.attention.wq.weight"].data_type if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)): return GGMLFileType.AllF32 @@ -983,10 +1022,15 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFi if output_type_str == "q4_1" or (output_type_str is None and isinstance(wq_type, QuantizedDataType) and wq_type.have_addends): if isinstance(model["output.weight"].data_type, QuantizedDataType): + if output_file_version == GGMLFileVersion.FILE_VERSION_GGJT_V3: + return GGMLFileType.MostlyQ4_1_FP16 return GGMLFileType.MostlyQ4_1 else: + # TODO: check file_version? return GGMLFileType.PerLayerIsQ4_1 if output_type_str == "q4_0" or (output_type_str is None and isinstance(wq_type, QuantizedDataType)): + if output_file_version == GGMLFileVersion.FILE_VERSION_GGJT_V3: + return GGMLFileType.MostlyQ4_0_FP16 return GGMLFileType.MostlyQ4_0 name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()} raise Exception(f"Unexpected combination of types: {name_to_type}") @@ -1107,6 +1151,8 @@ def default_outfile(model_paths: List[Path], params: Params) -> Path: GGMLFileType.MostlyQ4_0: "q4_0", GGMLFileType.MostlyQ4_1: "q4_1", GGMLFileType.PerLayerIsQ4_1: "q4_1", + GGMLFileType.MostlyQ4_0_FP16: "q4_0_fp16", + GGMLFileType.MostlyQ4_1_FP16: "q4_1_fp16", }[params.file_type] ret = model_paths[0].parent / f"ggml-model-{namestr}.bin" if ret in model_paths: @@ -1131,6 +1177,7 @@ def main(args_in: Optional[List[str]] = None) -> None: parser.add_argument("--outtype", choices=["f32", "f16", "q4_1", "q4_0"], help="output format (default: based on input)") parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("--outversion", type=str, choices=[x.name for x in list(GGMLFileVersion)], default=GGMLFileVersion(1).name, help="output file version") parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") args = parser.parse_args(args_in) @@ -1156,11 +1203,12 @@ def main(args_in: Optional[List[str]] = None) -> None: vocab = load_vocab(vocab_dir) model = model_plus.model model = do_necessary_conversions(model) - output_type = pick_output_type(model, args.outtype) + outfile_version = GGMLFileVersion[args.outversion] + output_type = pick_output_type(model, args.outtype, outfile_version) model = convert_to_output_type(model, output_type) params = Params.guessed(model, output_type) outfile = args.outfile or default_outfile(model_plus.paths, params) - OutputFile.write_all(outfile, params, model, vocab) + OutputFile.write_all(outfile, outfile_version, params, model, vocab) print(f"Wrote {outfile}") diff --git a/llama.cpp b/llama.cpp index 4cbc8d6b6..3af05ff53 100644 --- a/llama.cpp +++ b/llama.cpp @@ -460,6 +460,22 @@ struct llama_file_loader { hparams.n_layer = file.read_u32(); hparams.n_rot = file.read_u32(); hparams.ftype = (enum llama_ftype) file.read_u32(); + + if (file_version < LLAMA_FILE_VERSION_GGJT_V2) { + if (hparams.ftype != LLAMA_FTYPE_ALL_F32 && + hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 && + hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) { + throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1405)"); + } + } + + if (file_version < LLAMA_FILE_VERSION_GGJT_V3) { + if (hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || + hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 || + hparams.ftype == LLAMA_FTYPE_MOSTLY_Q8_0) { + throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1508)"); + } + } } void read_vocab() { vocab.id_to_token.resize(hparams.n_vocab); @@ -954,22 +970,6 @@ static void llama_model_load_internal( fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type)); } - if (file_version < LLAMA_FILE_VERSION_GGJT_V2) { - if (hparams.ftype != LLAMA_FTYPE_ALL_F32 && - hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 && - hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) { - throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1405)"); - } - } - - if (file_version < LLAMA_FILE_VERSION_GGJT_V3) { - if (hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || - hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 || - hparams.ftype == LLAMA_FTYPE_MOSTLY_Q8_0) { - throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1508)"); - } - } - if (vocab_only) { return; }