Add conversion from FP32-scale quants to FP16-scale quants

- Move the file_version check to the front of tensor data loading, so
  that broken tensor data is never parsed and the file_version check
  can do its job.
- Add fp32-to-fp16 scale conversion for Q4_0 and Q4_1 tensors.
Jason0214 2023-05-23 01:20:11 +08:00
parent 7e4ea5beff
commit ee9aaaaebc
2 changed files with 80 additions and 32 deletions
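
The substance of the change: Q4_0 and Q4_1 blocks keep their packed 4-bit quants, but the per-block fp32 fields (the scale, plus the addend for Q4_1) are narrowed to fp16. A minimal self-contained sketch of that repack for Q4_0, assuming the block layout visible in the diff below (one fp32 scale followed by 16 bytes of packed nibbles per 32-element group); the function name is illustrative, not part of the commit:

    import numpy as np

    def q4_0_scales_to_fp16(blocks: np.ndarray) -> np.ndarray:
        """blocks: uint8, shape (n_blocks, 20): 4-byte fp32 scale + 16 quant bytes.
        Returns uint8, shape (n_blocks, 18): 2-byte fp16 scale + the same quant bytes."""
        scales_f32 = blocks[:, :4].copy().view(np.float32)  # reinterpret the scale bytes
        quants = blocks[:, 4:]                              # packed 4-bit quants, untouched
        scales_f16 = scales_f32.astype(np.float16)          # narrow fp32 -> fp16
        return np.concatenate([scales_f16.view(np.uint8), quants], axis=1)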

convert.py

@@ -49,10 +49,12 @@ class QuantizedDataType:
     groupsize: int
     have_addends: bool
     have_g_idx: bool
+    use_fp16: bool
 
-DT_Q4_0 = QuantizedDataType(groupsize=32, have_addends=False, have_g_idx=False)
-DT_Q4_1 = QuantizedDataType(groupsize=32, have_addends=True, have_g_idx=False)
+DT_Q4_0 = QuantizedDataType(groupsize=32, have_addends=False, have_g_idx=False, use_fp16=False)
+DT_Q4_1 = QuantizedDataType(groupsize=32, have_addends=True, have_g_idx=False, use_fp16=False)
+DT_Q4_0_FP16 = QuantizedDataType(groupsize=32, have_addends=False, have_g_idx=False, use_fp16=True)
+DT_Q4_1_FP16 = QuantizedDataType(groupsize=32, have_addends=True, have_g_idx=False, use_fp16=True)
 
 DataType = Union[UnquantizedDataType, QuantizedDataType]
@@ -61,10 +63,16 @@ DATA_TYPE_TO_FTYPE: Dict[DataType, int] = {
     DT_F16: 1,
     DT_Q4_0: 2,
     DT_Q4_1: 3,
+    DT_Q4_0_FP16: 2,
+    DT_Q4_1_FP16: 3,
 }
 
-FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \
-    {ftype: dtype for (dtype, ftype) in DATA_TYPE_TO_FTYPE.items()}
+FTYPE_TO_DATA_TYPE: Dict[int, DataType] = {
+    0: DT_F32,
+    1: DT_F16,
+    2: DT_Q4_0,
+    3: DT_Q4_1,
+}
 
 DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
     DT_BF16: np.dtype(np.uint16),
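
A note on the FTYPE_TO_DATA_TYPE change above: the forward map is now many-to-one (DT_Q4_0 and DT_Q4_0_FP16 both serialize as ftype 2), so deriving the reverse map by inverting it would silently keep whichever pair was inserted last; the diff pins the reverse map by hand instead. A tiny stand-alone illustration of the hazard:

    mapping = {"DT_Q4_0": 2, "DT_Q4_0_FP16": 2}
    inverted = {v: k for (k, v) in mapping.items()}
    assert inverted[2] == "DT_Q4_0_FP16"  # last writer wins; the explicit map avoids this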
@@ -76,6 +84,12 @@ DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
 NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
     {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
 
+class GGMLFileVersion(enum.Enum):
+    FILE_VERSION_GGML = 0
+    FILE_VERSION_GGMF_V1 = 1  # added version field and scores in vocab
+    FILE_VERSION_GGJT_V1 = 1  # added padding
+    FILE_VERSION_GGJT_V2 = 2  # changed quantization format
+    FILE_VERSION_GGJT_V3 = 3  # changed Q4 and Q8 quantization format
+
 class GGMLFileType(enum.Enum):
     AllF32 = 0
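
Worth flagging in the new GGMLFileVersion enum: FILE_VERSION_GGMF_V1 and FILE_VERSION_GGJT_V1 share the value 1 (the two magics use the same on-disk version number), and in a Python Enum a repeated value makes the second name an alias of the first. A small demonstration of the behavior this implies:

    import enum

    class V(enum.Enum):
        GGMF_V1 = 1
        GGJT_V1 = 1  # alias: same value as GGMF_V1

    assert V.GGJT_V1 is V.GGMF_V1              # one member, two names
    assert [m.name for m in V] == ["GGMF_V1"]  # iteration skips aliases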
@@ -83,6 +97,8 @@ class GGMLFileType(enum.Enum):
     MostlyQ4_0 = 2  # except 1d tensors
     MostlyQ4_1 = 3  # except 1d tensors
     PerLayerIsQ4_1 = 4  # but tok_embeddings.weight and output.weight are F16
+    MostlyQ4_0_FP16 = 5  # MostlyQ4_0, but quant deltas are fp16
+    MostlyQ4_1_FP16 = 6  # MostlyQ4_1, but quant deltas are fp16
 
     def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
         if len(tensor.shape) == 1:
@@ -101,6 +117,10 @@ class GGMLFileType(enum.Enum):
                 return DT_F16
             else:
                 return DT_Q4_1
+        elif self == GGMLFileType.MostlyQ4_0_FP16:
+            return DT_Q4_0_FP16
+        elif self == GGMLFileType.MostlyQ4_1_FP16:
+            return DT_Q4_1_FP16
         else:
             raise ValueError(self)
@@ -322,7 +342,7 @@ class GGMLQuantizedTensor(Tensor):
     def __init__(self, ndarray: NDArray, shape: List[int], data_type: DataType) -> None:
         rows, columns = shape
-        assert data_type in (DT_Q4_1, DT_Q4_0)  # for now
+        assert data_type in (DT_Q4_1, DT_Q4_0, DT_Q4_0_FP16, DT_Q4_1_FP16)  # for now
         assert isinstance(data_type, QuantizedDataType)  # redundant, but mypy complains without this
         assert columns % data_type.groupsize == 0
         words_in_block = 6 if data_type == DT_Q4_1 else 5
@@ -333,6 +353,19 @@ class GGMLQuantizedTensor(Tensor):
     def astype(self, data_type: DataType) -> Tensor:
         if data_type == self.data_type:
             return self
+        assert isinstance(data_type, QuantizedDataType)
+        if data_type.use_fp16 and not self.data_type.use_fp16:
+            if self.data_type.have_addends:
+                float32s = self.ndarray[:, :, 0:2]
+                quants = self.ndarray[:, :, 2:]
+            else:
+                float32s = self.ndarray[:, :, 0:1]
+                quants = self.ndarray[:, :, 1:]
+            float16s = float32s.view(dtype=np.float32).astype(np.float16)
+            self.ndarray = np.concatenate([float16s.view(dtype=np.int8), quants.view(dtype=np.int8)], axis=2)
+            return self
+
         scales = self.ndarray[:, :, 0].view(np.float32)
         if self.data_type.have_addends:
             addends = self.ndarray[:, :, 1].view(np.float32)
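
Two things to note about the branch added to astype above: it mutates self.ndarray in place and returns self rather than building a new tensor, and the savings come only from the narrowed scale fields, since the packed quants are untouched. A back-of-envelope check using the 5/6-word block sizes from __init__ above:

    # bytes per 32-element block, before and after the fp16 repack
    q4_0 = (5 * 4, 2 + 16)      # fp32 scale + 16 quant bytes -> fp16 scale + quants
    q4_1 = (6 * 4, 2 + 2 + 16)  # scale + addend              -> both fields fp16
    print(q4_0[1] / q4_0[0])    # 0.9   -> ~10% smaller tensor data
    print(q4_1[1] / q4_1[0])    # ~0.83 -> ~17% smaller tensor data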
@@ -454,7 +487,7 @@ class GPTQForLLaMaQuantizedTensor(Tensor):
         ret = copy.copy(self)
         ret.addends = self.addends.repeat(old_groupsize // new_groupsize, axis=1)
         ret.scales = self.scales.repeat(old_groupsize // new_groupsize, axis=1)
-        ret.data_type = QuantizedDataType(groupsize=new_groupsize, have_addends=True, have_g_idx=False)
+        ret.data_type = QuantizedDataType(groupsize=new_groupsize, have_addends=True, have_g_idx=False, use_fp16=False)
         return ret
 
     def permute(self, n_head: int) -> Tensor:
@@ -495,7 +528,7 @@ class LazyTensor:
     def load(self) -> Tensor:
         ret = self._load()
-        assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
+        # assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
         return ret
 
     def astype(self, data_type: DataType) -> 'LazyTensor':
@@ -509,6 +542,12 @@ class LazyTensor:
         if data_type == self.data_type:
             return
         if isinstance(data_type, QuantizedDataType):
+            flip_fp16 = QuantizedDataType(data_type.groupsize,
+                                          data_type.have_addends,
+                                          data_type.have_g_idx,
+                                          not data_type.use_fp16)
+            if flip_fp16 == self.data_type:
+                return
             if not isinstance(self.data_type, QuantizedDataType):
                 raise Exception(f"Can't turn an unquantized tensor into a quantized type ({data_type})")
             if self.data_type.have_g_idx:
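
The flip_fp16 probe above leans on QuantizedDataType comparing field by field (it is declared with the dataclass-style fields shown in the first hunk), so a requested type that differs from the current one only in use_fp16 passes validation here, and the actual repack is deferred to GGMLQuantizedTensor.astype.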
@@ -780,7 +819,7 @@ def lazy_load_ggml_file(fp: io.BufferedReader, path: Path) -> ModelPlus:
     magic = must_read(fp, 4)[::-1]
     if magic in (b'ggmf', b'ggjt'):
         version, = struct.unpack("i", must_read(fp, 4))
-        assert version == 1
+        assert version == GGMLFileVersion.FILE_VERSION_GGMF_V1.value  # compare ints; a raw Enum member never equals an int
     else:
         assert magic == b'ggml'
         version = None
@@ -915,10 +954,10 @@ class OutputFile:
     def __init__(self, fname_out: Path) -> None:
         self.fout = open(fname_out, "wb")
 
-    def write_file_header(self, params: Params) -> None:
+    def write_file_header(self, outfile_version: GGMLFileVersion, params: Params) -> None:
         self.fout.write(b"ggjt"[::-1])  # magic
         values = [
-            1,  # file version
+            outfile_version.value,
             params.n_vocab,
             params.n_embd,
             params.n_mult,
@@ -953,10 +992,10 @@ class OutputFile:
         of.fout.close()
 
     @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def write_all(fname_out: Path, outfile_version: GGMLFileVersion, params: Params, model: LazyModel, vocab: Vocab) -> None:
         check_vocab_size(params, vocab)
         of = OutputFile(fname_out)
-        of.write_file_header(params)
+        of.write_file_header(outfile_version, params)
         print("Writing vocab...")
         of.write_vocab(vocab)
@@ -974,7 +1013,7 @@ class OutputFile:
     of.fout.close()
 
-def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
+def pick_output_type(model: LazyModel, output_type_str: Optional[str], output_file_version: GGMLFileVersion) -> GGMLFileType:
     wq_type = model["layers.0.attention.wq.weight"].data_type
     if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
         return GGMLFileType.AllF32
@@ -983,10 +1022,15 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
     if output_type_str == "q4_1" or (output_type_str is None and isinstance(wq_type, QuantizedDataType) and
                                      wq_type.have_addends):
         if isinstance(model["output.weight"].data_type, QuantizedDataType):
+            if output_file_version == GGMLFileVersion.FILE_VERSION_GGJT_V3:
+                return GGMLFileType.MostlyQ4_1_FP16
             return GGMLFileType.MostlyQ4_1
         else:
+            # TODO: check file_version?
             return GGMLFileType.PerLayerIsQ4_1
     if output_type_str == "q4_0" or (output_type_str is None and isinstance(wq_type, QuantizedDataType)):
+        if output_file_version == GGMLFileVersion.FILE_VERSION_GGJT_V3:
+            return GGMLFileType.MostlyQ4_0_FP16
         return GGMLFileType.MostlyQ4_0
     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
     raise Exception(f"Unexpected combination of types: {name_to_type}")
@@ -1107,6 +1151,8 @@ def default_outfile(model_paths: List[Path], params: Params) -> Path:
         GGMLFileType.MostlyQ4_0: "q4_0",
         GGMLFileType.MostlyQ4_1: "q4_1",
         GGMLFileType.PerLayerIsQ4_1: "q4_1",
+        GGMLFileType.MostlyQ4_0_FP16: "q4_0_fp16",
+        GGMLFileType.MostlyQ4_1_FP16: "q4_1_fp16",
     }[params.file_type]
     ret = model_paths[0].parent / f"ggml-model-{namestr}.bin"
     if ret in model_paths:
@@ -1131,6 +1177,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--outtype", choices=["f32", "f16", "q4_1", "q4_0"], help="output format (default: based on input)")
     parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+    parser.add_argument("--outversion", type=str, choices=[x.name for x in list(GGMLFileVersion)], default=GGMLFileVersion(1).name, help="output file version")
     parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
     args = parser.parse_args(args_in)
@@ -1156,11 +1203,12 @@ def main(args_in: Optional[List[str]] = None) -> None:
         vocab = load_vocab(vocab_dir)
     model = model_plus.model
     model = do_necessary_conversions(model)
-    output_type = pick_output_type(model, args.outtype)
+    outfile_version = GGMLFileVersion[args.outversion]
+    output_type = pick_output_type(model, args.outtype, outfile_version)
     model = convert_to_output_type(model, output_type)
     params = Params.guessed(model, output_type)
     outfile = args.outfile or default_outfile(model_plus.paths, params)
-    OutputFile.write_all(outfile, params, model, vocab)
+    OutputFile.write_all(outfile, outfile_version, params, model, vocab)
     print(f"Wrote {outfile}")
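
With all of this wired together, the fp16-scale output is opt-in via --outversion. A hypothetical invocation (paths illustrative) converting an existing fp32-scale q4_0 model; note that FILE_VERSION_GGJT_V1, being an alias of FILE_VERSION_GGMF_V1 as shown earlier, does not appear in the choices list:

    python convert.py models/7B/ggml-model-q4_0.bin --outtype q4_0 \
        --outversion FILE_VERSION_GGJT_V3 --outfile models/7B/ggml-model-q4_0-v3.bin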

llama.cpp

@@ -460,6 +460,22 @@ struct llama_file_loader {
         hparams.n_layer = file.read_u32();
         hparams.n_rot = file.read_u32();
         hparams.ftype = (enum llama_ftype) file.read_u32();
+        if (file_version < LLAMA_FILE_VERSION_GGJT_V2) {
+            if (hparams.ftype != LLAMA_FTYPE_ALL_F32 &&
+                hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 &&
+                hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) {
+                throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1405)");
+            }
+        }
+        if (file_version < LLAMA_FILE_VERSION_GGJT_V3) {
+            if (hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
+                hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 ||
+                hparams.ftype == LLAMA_FTYPE_MOSTLY_Q8_0) {
+                throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1508)");
+            }
+        }
     }
     void read_vocab() {
         vocab.id_to_token.resize(hparams.n_vocab);
@@ -954,22 +970,6 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type));
     }
 
-    if (file_version < LLAMA_FILE_VERSION_GGJT_V2) {
-        if (hparams.ftype != LLAMA_FTYPE_ALL_F32 &&
-            hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 &&
-            hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) {
-            throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1405)");
-        }
-    }
-    if (file_version < LLAMA_FILE_VERSION_GGJT_V3) {
-        if (hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
-            hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 ||
-            hparams.ftype == LLAMA_FTYPE_MOSTLY_Q8_0) {
-            throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1508)");
-        }
-    }
-
     if (vocab_only) {
         return;
     }