convert-hf : support direct Q8_0 conversion (#7234)
* convert-hf : support q8_0 conversion * convert-hf : add missing ftype This was messing with the checksums otherwise. * convert-hf : add missing ftype to Baichuan and Xverse I didn't notice these on my first pass.
This commit is contained in:
parent
614d3b914e
commit
ee52225067
5 changed files with 169 additions and 58 deletions
|
@ -240,23 +240,6 @@ class Model:
|
|||
return False
|
||||
|
||||
def write_tensors(self):
|
||||
# same as ggml_compute_fp32_to_bf16 in ggml-impl.h
|
||||
def np_fp32_to_bf16(n: np.ndarray):
|
||||
# force nan to quiet
|
||||
n = np.where((n & 0x7fffffff) > 0x7f800000, (n & 0xffff0000) | (64 << 16), n)
|
||||
# flush subnormals to zero
|
||||
n = np.where((n & 0x7f800000) == 0, n & 0x80000000, n)
|
||||
# round to nearest even
|
||||
n = (n + (0x7fff + ((n >> 16) & 1))) >> 16
|
||||
return n.astype(np.int16)
|
||||
|
||||
# Doing this row-wise is much, much faster than element-wise, hence the signature
|
||||
v_fp32_to_bf16 = np.vectorize(np_fp32_to_bf16, otypes=[np.int16], signature="(n)->(n)")
|
||||
if self.lazy:
|
||||
# TODO: find a way to implicitly wrap np.vectorize functions
|
||||
# NOTE: the type is changed to reflect otypes passed to np.vectorize above
|
||||
v_fp32_to_bf16 = gguf.LazyNumpyTensor._wrap_fn(v_fp32_to_bf16, meta_noop=np.int16)
|
||||
|
||||
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
|
||||
|
||||
for name, data_torch in self.get_tensors():
|
||||
|
@ -309,27 +292,31 @@ class Model:
|
|||
))
|
||||
|
||||
if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
|
||||
if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
|
||||
if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
|
||||
data = gguf.quantize_bf16(data)
|
||||
assert data.dtype == np.int16
|
||||
data_qtype = gguf.GGMLQuantizationType.BF16
|
||||
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
|
||||
data = gguf.quantize_q8_0(data)
|
||||
assert data.dtype == np.uint8
|
||||
data_qtype = gguf.GGMLQuantizationType.Q8_0
|
||||
|
||||
else: # default to float16 for quantized tensors
|
||||
if data_dtype != np.float16:
|
||||
data = data.astype(np.float16)
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
|
||||
if data_dtype != np.float32:
|
||||
data = data.astype(np.float32)
|
||||
data = v_fp32_to_bf16(data.view(np.int32))
|
||||
assert data.dtype == np.int16
|
||||
data_qtype = gguf.GGMLQuantizationType.BF16
|
||||
|
||||
else: # by default, convert to float32
|
||||
if data_qtype is None: # by default, convert to float32
|
||||
if data_dtype != np.float32:
|
||||
data = data.astype(np.float32)
|
||||
data_qtype = gguf.GGMLQuantizationType.F32
|
||||
|
||||
assert data_qtype is not None
|
||||
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[data_qtype]
|
||||
# reverse shape to make it similar to the internal ggml dimension order
|
||||
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
|
||||
shape_str = f"""{{{', '.join(str(n) for n in reversed(
|
||||
(*data.shape[:-1], data.shape[-1] * data.dtype.itemsize // type_size * block_size))
|
||||
)}}}"""
|
||||
|
||||
# n_dims is implicit in the shape
|
||||
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
|
||||
|
@ -859,6 +846,7 @@ class BaichuanModel(Model):
|
|||
self.gguf_writer.add_head_count(head_count)
|
||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||
if self.hparams["rope_scaling"].get("type") == "linear":
|
||||
|
@ -981,6 +969,7 @@ class XverseModel(Model):
|
|||
self.gguf_writer.add_head_count(head_count)
|
||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||
if self.hparams["rope_scaling"].get("type") == "linear":
|
||||
|
@ -1215,6 +1204,7 @@ class StableLMModel(Model):
|
|||
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
|
||||
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
|
||||
self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"]))
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
_q_norms: list[dict[str, Tensor]] | None = None
|
||||
_k_norms: list[dict[str, Tensor]] | None = None
|
||||
|
@ -1591,6 +1581,7 @@ class QwenModel(Model):
|
|||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
|
||||
@Model.register("Qwen2ForCausalLM")
|
||||
|
@ -1828,6 +1819,7 @@ class PlamoModel(Model):
|
|||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
|
||||
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
def shuffle_attn_q_weight(self, data_torch):
|
||||
assert data_torch.size() == (5120, 5120)
|
||||
|
@ -2007,6 +1999,7 @@ in chat mode so that the conversation can end normally.")
|
|||
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
num_heads = self.hparams["num_attention_heads"]
|
||||
|
@ -2415,25 +2408,15 @@ class LazyTorchTensor(gguf.LazyBase):
|
|||
def numpy(self) -> gguf.LazyNumpyTensor:
|
||||
dtype = self._dtype_map[self.dtype]
|
||||
return gguf.LazyNumpyTensor(
|
||||
meta=np.lib.stride_tricks.as_strided(np.zeros(1, dtype), self.shape, (0 for _ in self.shape)),
|
||||
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
|
||||
lazy=self._lazy,
|
||||
args=(self,),
|
||||
func=(lambda s: s[0].numpy())
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def eager_to_meta(cls, t: Tensor) -> Tensor:
|
||||
if t.is_meta:
|
||||
return t
|
||||
return t.detach().to("meta")
|
||||
|
||||
@classmethod
|
||||
def meta_with_dtype(cls, m: Tensor, dtype: torch.dtype) -> Tensor:
|
||||
m = m.detach()
|
||||
if not m.is_meta:
|
||||
m = m.to("meta")
|
||||
m.dtype = dtype
|
||||
return m
|
||||
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: torch.Size) -> Tensor:
|
||||
return torch.empty(size=shape, dtype=dtype, device="meta")
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
|
@ -2464,8 +2447,8 @@ def parse_args() -> argparse.Namespace:
|
|||
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "auto"], default="f16",
|
||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
|
||||
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bigendian", action="store_true",
|
||||
|
@ -2523,6 +2506,7 @@ def main() -> None:
|
|||
"f32": gguf.LlamaFileType.ALL_F32,
|
||||
"f16": gguf.LlamaFileType.MOSTLY_F16,
|
||||
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
||||
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
|
||||
"auto": gguf.LlamaFileType.GUESSED,
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue