Add conversion from FP32-quant to FP16-quant models

- Move the file_version check to the front of tensor data loading, so that broken tensor data is never parsed and the file_version check can do its job.
- Add FP32-to-FP16 conversion for Q4_0 and Q4_1 tensors.
parent: 7e4ea5beff
commit: ee9aaaaebc

2 changed files with 80 additions and 32 deletions
convert.py (80 changes)
```diff
@@ -49,10 +49,12 @@ class QuantizedDataType:
     groupsize: int
     have_addends: bool
     have_g_idx: bool
+    use_fp16: bool

-DT_Q4_0 = QuantizedDataType(groupsize=32, have_addends=False, have_g_idx=False)
-DT_Q4_1 = QuantizedDataType(groupsize=32, have_addends=True, have_g_idx=False)
+DT_Q4_0 = QuantizedDataType(groupsize=32, have_addends=False, have_g_idx=False, use_fp16=False)
+DT_Q4_1 = QuantizedDataType(groupsize=32, have_addends=True, have_g_idx=False, use_fp16=False)
+DT_Q4_0_FP16 = QuantizedDataType(groupsize=32, have_addends=False, have_g_idx=False, use_fp16=True)
+DT_Q4_1_FP16 = QuantizedDataType(groupsize=32, have_addends=True, have_g_idx=False, use_fp16=True)

 DataType = Union[UnquantizedDataType, QuantizedDataType]
```
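As context for the new `use_fp16` flag: it only changes how the per-block scale (and addend, for Q4_1) is stored; the packed 4-bit quants are untouched. A small sketch of the implied block sizes, assuming the groupsize-32 layouts above (the helper is hypothetical, not part of convert.py):

```python
# Hypothetical helper illustrating the block sizes implied by the types above.
def block_bytes(groupsize: int, have_addends: bool, use_fp16: bool) -> int:
    float_width = 2 if use_fp16 else 4                   # fp16 vs fp32 header floats
    header = (2 if have_addends else 1) * float_width    # scale, plus addend for Q4_1
    return header + groupsize // 2                       # two 4-bit quants per byte

assert block_bytes(32, have_addends=False, use_fp16=False) == 20  # Q4_0
assert block_bytes(32, have_addends=True,  use_fp16=False) == 24  # Q4_1
assert block_bytes(32, have_addends=False, use_fp16=True)  == 18  # Q4_0_FP16
assert block_bytes(32, have_addends=True,  use_fp16=True)  == 20  # Q4_1_FP16
```

The 20-byte Q4_0 and 24-byte Q4_1 blocks match the `words_in_block` values (5 and 6 32-bit words) used later in `GGMLQuantizedTensor`.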
```diff
@@ -61,10 +63,16 @@ DATA_TYPE_TO_FTYPE: Dict[DataType, int] = {
     DT_F16: 1,
     DT_Q4_0: 2,
     DT_Q4_1: 3,
+    DT_Q4_0_FP16: 2,
+    DT_Q4_1_FP16: 3,
 }

-FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \
-    {ftype: dtype for (dtype, ftype) in DATA_TYPE_TO_FTYPE.items()}
+FTYPE_TO_DATA_TYPE: Dict[int, DataType] = {
+    0: DT_F32,
+    1: DT_F16,
+    2: DT_Q4_0,
+    3: DT_Q4_1,
+}

 DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
     DT_BF16: np.dtype(np.uint16),
```
```diff
@@ -76,6 +84,12 @@ DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = {
 NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \
     {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}

+class GGMLFileVersion(enum.Enum):
+    FILE_VERSION_GGML = 0
+    FILE_VERSION_GGMF_V1 = 1  # added version field and scores in vocab
+    FILE_VERSION_GGJT_V1 = 1  # added padding
+    FILE_VERSION_GGJT_V2 = 2  # changed quantization format
+    FILE_VERSION_GGJT_V3 = 3  # changed Q4 and Q8 quantization format
+
 class GGMLFileType(enum.Enum):
     AllF32 = 0
```
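One quirk worth noting: `FILE_VERSION_GGMF_V1` and `FILE_VERSION_GGJT_V1` are both assigned the value 1, and Python's `enum` treats a second name with a duplicate value as an alias of the first. A standalone sketch mirroring the enum above shows the consequence:

```python
import enum

class GGMLFileVersion(enum.Enum):
    FILE_VERSION_GGML = 0
    FILE_VERSION_GGMF_V1 = 1
    FILE_VERSION_GGJT_V1 = 1  # duplicate value: becomes an alias of GGMF_V1
    FILE_VERSION_GGJT_V2 = 2
    FILE_VERSION_GGJT_V3 = 3

# The alias is the same object, and iteration skips aliases entirely.
assert GGMLFileVersion.FILE_VERSION_GGJT_V1 is GGMLFileVersion.FILE_VERSION_GGMF_V1
assert [v.name for v in GGMLFileVersion] == [
    "FILE_VERSION_GGML", "FILE_VERSION_GGMF_V1",
    "FILE_VERSION_GGJT_V2", "FILE_VERSION_GGJT_V3",
]
```

This matters later: the new `--outversion` flag builds its choices by iterating the enum, so `FILE_VERSION_GGJT_V1` will not appear as a selectable name.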
```diff
@@ -83,6 +97,8 @@ class GGMLFileType(enum.Enum):
     MostlyQ4_0 = 2  # except 1d tensors
     MostlyQ4_1 = 3  # except 1d tensors
     PerLayerIsQ4_1 = 4  # but tok_embeddings.weight and output.weight are F16
+    MostlyQ4_0_FP16 = 5  # MostlyQ4_0 + Quants delta is fp16
+    MostlyQ4_1_FP16 = 6  # MostlyQ4_1 + Quants delta is fp16

     def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType:
         if len(tensor.shape) == 1:
@@ -101,6 +117,10 @@ class GGMLFileType(enum.Enum):
                 return DT_F16
             else:
                 return DT_Q4_1
+        elif self == GGMLFileType.MostlyQ4_0_FP16:
+            return DT_Q4_0_FP16
+        elif self == GGMLFileType.MostlyQ4_1_FP16:
+            return DT_Q4_1_FP16
         else:
             raise ValueError(self)
```
```diff
@@ -322,7 +342,7 @@ class GGMLQuantizedTensor(Tensor):
     def __init__(self, ndarray: NDArray, shape: List[int], data_type: DataType) -> None:
         rows, columns = shape
-        assert data_type in (DT_Q4_1, DT_Q4_0)  # for now
+        assert data_type in (DT_Q4_1, DT_Q4_0, DT_Q4_0_FP16, DT_Q4_1_FP16)  # for now
         assert isinstance(data_type, QuantizedDataType)  # redundant, but mypy complains without this
         assert columns % data_type.groupsize == 0
         words_in_block = 6 if data_type == DT_Q4_1 else 5
```
```diff
@@ -333,6 +353,19 @@ class GGMLQuantizedTensor(Tensor):
     def astype(self, data_type: DataType) -> Tensor:
         if data_type == self.data_type:
             return self
+
+        assert isinstance(data_type, QuantizedDataType)
+        if data_type.use_fp16 and not self.data_type.use_fp16:
+            if self.data_type.have_addends:
+                float32s = self.ndarray[:, :, 0:2]
+                quants = self.ndarray[:, :, 2:]
+            else:
+                float32s = self.ndarray[:, :, 0:1]
+                quants = self.ndarray[:, :, 1:]
+            float16s = float32s.view(dtype=np.float32).astype(np.float16)
+            self.ndarray = np.concatenate([float16s.view(dtype=np.int8), quants.view(dtype=np.int8)], axis=2)
+            return self
+
         scales = self.ndarray[:, :, 0].view(np.float32)
         if self.data_type.have_addends:
             addends = self.ndarray[:, :, 1].view(np.float32)
```
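The conversion above reinterprets each block's leading fp32 bytes as floats, narrows them to fp16, and re-packs the bytes. A self-contained numpy sketch of the same view/astype trick on a single toy Q4_0-style block (illustrative data, not convert.py code):

```python
import numpy as np

# One (row, block) of a Q4_0-style layout viewed as bytes:
# a 4-byte fp32 scale followed by 16 bytes of packed 4-bit quants.
block = np.zeros((1, 1, 20), dtype=np.int8)
block[:, :, 0:4] = np.array([0.125], dtype=np.float32).view(np.int8)

scale_bytes = block[:, :, 0:4]
quant_bytes = block[:, :, 4:]

# Reinterpret the header bytes as fp32, narrow to fp16, re-pack as bytes.
scale_fp16 = scale_bytes.view(dtype=np.float32).astype(np.float16)
packed = np.concatenate([scale_fp16.view(dtype=np.int8), quant_bytes], axis=2)

assert packed.shape == (1, 1, 18)  # the 20-byte block shrank to 18 bytes
assert packed[:, :, 0:2].view(np.float16)[0, 0, 0] == 0.125
```

Narrowing to fp16 loses precision in the scales, but 0.125 (like any small power of two) survives exactly, which keeps the sketch deterministic.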
```diff
@@ -454,7 +487,7 @@ class GPTQForLLaMaQuantizedTensor(Tensor):
         ret = copy.copy(self)
         ret.addends = self.addends.repeat(old_groupsize // new_groupsize, axis=1)
         ret.scales = self.scales.repeat(old_groupsize // new_groupsize, axis=1)
-        ret.data_type = QuantizedDataType(groupsize=new_groupsize, have_addends=True, have_g_idx=False)
+        ret.data_type = QuantizedDataType(groupsize=new_groupsize, have_addends=True, have_g_idx=False, use_fp16=False)
         return ret

     def permute(self, n_head: int) -> Tensor:
@@ -495,7 +528,7 @@ class LazyTensor:
     def load(self) -> Tensor:
         ret = self._load()
-        assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
+        # assert ret.data_type == self.data_type, (self.data_type, ret.data_type, self.description)
         return ret

     def astype(self, data_type: DataType) -> 'LazyTensor':
@@ -509,6 +542,12 @@ class LazyTensor:
         if data_type == self.data_type:
             return
+        if isinstance(data_type, QuantizedDataType):
+            flip_fp16 = QuantizedDataType(data_type.groupsize,
+                                          data_type.have_addends,
+                                          data_type.have_g_idx, not data_type.use_fp16)
+            if flip_fp16 == self.data_type:
+                return
+
         if not isinstance(self.data_type, QuantizedDataType):
             raise Exception(f"Can't turn an unquantized tensor into a quantized type ({data_type})")
         if self.data_type.have_g_idx:
```
```diff
@@ -780,7 +819,7 @@ def lazy_load_ggml_file(fp: io.BufferedReader, path: Path) -> ModelPlus:
     magic = must_read(fp, 4)[::-1]
     if magic in (b'ggmf', b'ggjt'):
         version, = struct.unpack("i", must_read(fp, 4))
-        assert version == 1
+        assert version == GGMLFileVersion.FILE_VERSION_GGMF_V1
     else:
         assert magic == b'ggml'
         version = None
@@ -915,10 +954,10 @@ class OutputFile:
     def __init__(self, fname_out: Path) -> None:
         self.fout = open(fname_out, "wb")

-    def write_file_header(self, params: Params) -> None:
+    def write_file_header(self, outfile_version: GGMLFileVersion, params: Params) -> None:
         self.fout.write(b"ggjt"[::-1])  # magic
         values = [
-            1,  # file version
+            outfile_version.value,
             params.n_vocab,
             params.n_embd,
             params.n_mult,
```
```diff
@@ -953,10 +992,10 @@ class OutputFile:
         of.fout.close()

     @staticmethod
-    def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None:
+    def write_all(fname_out: Path, outfile_version: GGMLFileVersion, params: Params, model: LazyModel, vocab: Vocab) -> None:
         check_vocab_size(params, vocab)
         of = OutputFile(fname_out)
-        of.write_file_header(params)
+        of.write_file_header(outfile_version, params)
         print("Writing vocab...")
         of.write_vocab(vocab)
@@ -974,7 +1013,7 @@ class OutputFile:
         of.fout.close()


-def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
+def pick_output_type(model: LazyModel, output_type_str: Optional[str], output_file_version: GGMLFileVersion) -> GGMLFileType:
     wq_type = model["layers.0.attention.wq.weight"].data_type
     if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)):
         return GGMLFileType.AllF32
```
```diff
@@ -983,10 +1022,15 @@ def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType:
     if output_type_str == "q4_1" or (output_type_str is None and isinstance(wq_type, QuantizedDataType) and
                                      wq_type.have_addends):
         if isinstance(model["output.weight"].data_type, QuantizedDataType):
+            if output_file_version == GGMLFileVersion.FILE_VERSION_GGJT_V3:
+                return GGMLFileType.MostlyQ4_1_FP16
             return GGMLFileType.MostlyQ4_1
         else:
+            # TODO: check file_version?
             return GGMLFileType.PerLayerIsQ4_1
     if output_type_str == "q4_0" or (output_type_str is None and isinstance(wq_type, QuantizedDataType)):
+        if output_file_version == GGMLFileVersion.FILE_VERSION_GGJT_V3:
+            return GGMLFileType.MostlyQ4_0_FP16
         return GGMLFileType.MostlyQ4_0
     name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()}
     raise Exception(f"Unexpected combination of types: {name_to_type}")
@@ -1107,6 +1151,8 @@ def default_outfile(model_paths: List[Path], params: Params) -> Path:
         GGMLFileType.MostlyQ4_0: "q4_0",
         GGMLFileType.MostlyQ4_1: "q4_1",
         GGMLFileType.PerLayerIsQ4_1: "q4_1",
+        GGMLFileType.MostlyQ4_0_FP16: "q4_0_fp16",
+        GGMLFileType.MostlyQ4_1_FP16: "q4_1_fp16",
     }[params.file_type]
     ret = model_paths[0].parent / f"ggml-model-{namestr}.bin"
     if ret in model_paths:
```
```diff
@@ -1131,6 +1177,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
     parser.add_argument("--outtype", choices=["f32", "f16", "q4_1", "q4_0"], help="output format (default: based on input)")
     parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
     parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+    parser.add_argument("--outversion", type=str, choices=[x.name for x in list(GGMLFileVersion)], default=GGMLFileVersion(1).name, help="output file version")
     parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
     args = parser.parse_args(args_in)

@@ -1156,11 +1203,12 @@ def main(args_in: Optional[List[str]] = None) -> None:
         vocab = load_vocab(vocab_dir)
     model = model_plus.model
     model = do_necessary_conversions(model)
-    output_type = pick_output_type(model, args.outtype)
+    outfile_version = GGMLFileVersion[args.outversion]
+    output_type = pick_output_type(model, args.outtype, outfile_version)
     model = convert_to_output_type(model, output_type)
     params = Params.guessed(model, output_type)
     outfile = args.outfile or default_outfile(model_plus.paths, params)
-    OutputFile.write_all(outfile, params, model, vocab)
+    OutputFile.write_all(outfile, outfile_version, params, model, vocab)
     print(f"Wrote {outfile}")
```
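Taken together, a hypothetical end-to-end invocation through `main()`'s `args_in` hook; the model path is illustrative and this assumes convert.py is importable as a module:

```python
# Hypothetical driver, not part of the commit; "models/7B" is an example path.
from convert import main

main([
    "--outtype", "q4_0",
    "--outversion", "FILE_VERSION_GGJT_V3",  # selects MostlyQ4_0_FP16 in pick_output_type
    "models/7B",
])
```

With `--outversion` left at its default (`FILE_VERSION_GGMF_V1`), the FP16 paths are skipped and the converter behaves as before.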
llama.cpp (32 changes)
```diff
@@ -460,6 +460,22 @@ struct llama_file_loader {
         hparams.n_layer = file.read_u32();
         hparams.n_rot = file.read_u32();
         hparams.ftype = (enum llama_ftype) file.read_u32();
+
+        if (file_version < LLAMA_FILE_VERSION_GGJT_V2) {
+            if (hparams.ftype != LLAMA_FTYPE_ALL_F32 &&
+                hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 &&
+                hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) {
+                throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1405)");
+            }
+        }
+
+        if (file_version < LLAMA_FILE_VERSION_GGJT_V3) {
+            if (hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
+                hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 ||
+                hparams.ftype == LLAMA_FTYPE_MOSTLY_Q8_0) {
+                throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1508)");
+            }
+        }
     }
     void read_vocab() {
         vocab.id_to_token.resize(hparams.n_vocab);
```
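Because the checks now run right after `hparams.ftype` is read, incompatible files are rejected before any tensor data is parsed. A compact Python sketch of the same gate, for reference (not llama.cpp code; names mirror the constants above):

```python
# Sketch of the loader's version/ftype gate; constants mirror llama.h values.
GGJT_V2, GGJT_V3 = 2, 3

def check_ftype(file_version: int, ftype: str) -> None:
    # pre-V2 files: only f32/f16/q8_0 layouts are still readable (see pull/1405)
    if file_version < GGJT_V2 and ftype not in ("ALL_F32", "MOSTLY_F16", "MOSTLY_Q8_0"):
        raise ValueError("this format is no longer supported (pull/1405)")
    # pre-V3 files: the Q4/Q8 block layouts changed again (see pull/1508)
    if file_version < GGJT_V3 and ftype in ("MOSTLY_Q4_0", "MOSTLY_Q4_1", "MOSTLY_Q8_0"):
        raise ValueError("this format is no longer supported (pull/1508)")

check_ftype(3, "MOSTLY_Q4_0")      # V3 file: accepted
try:
    check_ftype(2, "MOSTLY_Q4_0")  # V2 file: rejected, Q4 layout changed in V3
except ValueError as err:
    print(err)
```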
```diff
@@ -954,22 +970,6 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type));
     }

-    if (file_version < LLAMA_FILE_VERSION_GGJT_V2) {
-        if (hparams.ftype != LLAMA_FTYPE_ALL_F32 &&
-            hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 &&
-            hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) {
-            throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1405)");
-        }
-    }
-
-    if (file_version < LLAMA_FILE_VERSION_GGJT_V3) {
-        if (hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ||
-            hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 ||
-            hparams.ftype == LLAMA_FTYPE_MOSTLY_Q8_0) {
-            throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1508)");
-        }
-    }
-
     if (vocab_only) {
         return;
     }
```