Merge branch 'master' into compilade/convert-hf-refactor

Francis Couture-Harpin 2024-05-03 16:20:54 -04:00
commit 3e5e0dced5
24 changed files with 480 additions and 399 deletions


@@ -2,6 +2,7 @@
from __future__ import annotations
import logging
import argparse
import contextlib
import json
@@ -25,6 +26,8 @@ import gguf
from convert import LlamaHfVocab
logger = logging.getLogger("hf-to-gguf")
###### MODEL DEFINITIONS ######
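The module-level logger introduced above routes all output through Python's logging machinery instead of bare print. A minimal sketch of the level split this diff applies (the messages are illustrative; levels take effect once main() calls logging.basicConfig):

import logging

logger = logging.getLogger("hf-to-gguf")

logger.debug("chktok: [1, 2, 3]")                        # diagnostics, hidden unless --verbose
logger.info("gguf: context length = 4096")               # normal progress output
logger.warning("BPE pre-tokenizer was not recognized!")  # problems that need attention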
@@ -99,7 +102,7 @@ class Model:
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
for part_name in self.part_names:
print(f"gguf: loading model part '{part_name}'")
logger.info(f"gguf: loading model part '{part_name}'")
ctx: ContextManager[Any]
if self.is_safetensors:
from safetensors import safe_open
@@ -115,8 +118,7 @@ class Model:
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
name: str = gguf.TENSOR_NAMES[key]
if key not in gguf.MODEL_TENSORS[self.model_arch]:
print(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
sys.exit()
raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
if "{bid}" in name:
assert bid is not None
name = name.format(bid=bid)
@@ -125,8 +127,7 @@ class Model:
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()
raise ValueError(f"Can not map tensor {name!r}")
return new_name
def set_gguf_parameters(self):
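Raising ValueError instead of print plus sys.exit() lets a caller recover from a single bad tensor rather than having the whole process die. A minimal sketch of what the change enables (map_or_skip and the fallback behaviour are hypothetical, not part of the diff):

import logging

logger = logging.getLogger("hf-to-gguf")

def map_or_skip(model, name: str) -> str | None:
    # With sys.exit() this try/except was impossible: the process
    # terminated before any caller could react.
    try:
        return model.map_tensor_name(name)
    except ValueError as err:
        logger.warning(f"skipping tensor: {err}")
        return None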
@@ -135,42 +136,42 @@ class Model:
if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
self.gguf_writer.add_context_length(n_ctx)
print(f"gguf: context length = {n_ctx}")
logger.info(f"gguf: context length = {n_ctx}")
n_embd = self.find_hparam(["hidden_size", "n_embd"])
self.gguf_writer.add_embedding_length(n_embd)
print(f"gguf: embedding length = {n_embd}")
logger.info(f"gguf: embedding length = {n_embd}")
if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
self.gguf_writer.add_feed_forward_length(n_ff)
print(f"gguf: feed forward length = {n_ff}")
logger.info(f"gguf: feed forward length = {n_ff}")
n_head = self.find_hparam(["num_attention_heads", "n_head"])
self.gguf_writer.add_head_count(n_head)
print(f"gguf: head count = {n_head}")
logger.info(f"gguf: head count = {n_head}")
if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
self.gguf_writer.add_head_count_kv(n_head_kv)
print(f"gguf: key-value head count = {n_head_kv}")
logger.info(f"gguf: key-value head count = {n_head_kv}")
if (rope_theta := self.hparams.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
print(f"gguf: rope theta = {rope_theta}")
logger.info(f"gguf: rope theta = {rope_theta}")
if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
print(f"gguf: rms norm epsilon = {f_rms_eps}")
logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
print(f"gguf: layer norm epsilon = {f_norm_eps}")
logger.info(f"gguf: layer norm epsilon = {f_norm_eps}")
if (n_experts := self.hparams.get("num_local_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
print(f"gguf: expert count = {n_experts}")
logger.info(f"gguf: expert count = {n_experts}")
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)
print(f"gguf: experts used count = {n_experts_used}")
logger.info(f"gguf: experts used count = {n_experts_used}")
self.gguf_writer.add_file_type(self.ftype)
print(f"gguf: file type = {self.ftype}")
logger.info(f"gguf: file type = {self.ftype}")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
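Each block in set_gguf_parameters above uses the walrus operator to probe an optional hyperparameter and bind it in a single expression. A standalone sketch of the pattern (find_hparam here is a simplified stand-in for the real method):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hf-to-gguf")

hparams = {"hidden_size": 4096, "rope_theta": 10000.0}

def find_hparam(keys, optional=False):
    # Return the value of the first key present, mirroring how the
    # converter probes several possible config.json spellings.
    for key in keys:
        if key in hparams:
            return hparams[key]
    if optional:
        return None
    raise KeyError(f"could not find any of {keys}")

# := tests for presence and binds the value on one line
if (rope_theta := hparams.get("rope_theta")) is not None:
    logger.info(f"gguf: rope theta = {rope_theta}")
if (n_ctx := find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
    logger.info(f"gguf: context length = {n_ctx}")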
@@ -236,7 +237,7 @@ class Model:
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
# n_dims is implicit in the shape
print(f"{new_name}, shape = {shape_str}, {old_dtype} --> {data.dtype}")
logger.info(f"{new_name}, shape = {shape_str}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)
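The shape string above reverses the NumPy shape because ggml lists tensor dimensions in the opposite order, and the doubled braces in the f-string emit literal { and }. A quick sketch:

import numpy as np

data = np.zeros((32, 4096), dtype=np.float16)
shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}"
assert shape_str == "{4096, 32}"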
@@ -330,8 +331,8 @@ class Model:
chktok = tokenizer.encode(chktxt)
chkhsh = sha256(str(chktok).encode()).hexdigest()
print(f"chktok: {chktok}")
print(f"chkhsh: {chkhsh}")
logger.debug(f"chktok: {chktok}")
logger.debug(f"chkhsh: {chkhsh}")
res = None
@@ -364,22 +365,22 @@ class Model:
res = "gpt-2"
if res is None:
print("\n")
print("**************************************************************************************")
print("** WARNING: The BPE pre-tokenizer was not recognized!")
print("** There are 2 possible reasons for this:")
print("** - the model has not been added to convert-hf-to-gguf-update.py yet")
print("** - the pre-tokenization config has changed upstream")
print("** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.")
print("** ref: https://github.com/ggerganov/llama.cpp/pull/6920")
print("**")
print(f"** chkhsh: {chkhsh}")
print("**************************************************************************************")
print("\n")
logger.warning("\n")
logger.warning("**************************************************************************************")
logger.warning("** WARNING: The BPE pre-tokenizer was not recognized!")
logger.warning("** There are 2 possible reasons for this:")
logger.warning("** - the model has not been added to convert-hf-to-gguf-update.py yet")
logger.warning("** - the pre-tokenization config has changed upstream")
logger.warning("** Check your model files and convert-hf-to-gguf-update.py and update them accordingly.")
logger.warning("** ref: https://github.com/ggerganov/llama.cpp/pull/6920")
logger.warning("**")
logger.warning(f"** chkhsh: {chkhsh}")
logger.warning("**************************************************************************************")
logger.warning("\n")
raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()")
print(f"tokenizer.ggml.pre: {res}")
print(f"chkhsh: {chkhsh}")
logger.debug(f"tokenizer.ggml.pre: {res}")
logger.debug(f"chkhsh: {chkhsh}")
return res
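get_vocab_base_pre fingerprints the pre-tokenizer by encoding a fixed probe string and hashing the resulting token IDs: two tokenizers that split text identically produce identical IDs, hence identical hashes. A minimal sketch of the mechanism (the hash table entry is a placeholder, not a real value):

from hashlib import sha256

def fingerprint(tokenizer, chktxt: str) -> str:
    chktok = tokenizer.encode(chktxt)  # token IDs for the probe text
    return sha256(str(chktok).encode()).hexdigest()

KNOWN_HASHES = {"<sha256 of llama-bpe probe>": "llama-bpe"}  # placeholder entry

def detect(tokenizer, chktxt: str) -> str:
    res = KNOWN_HASHES.get(fingerprint(tokenizer, chktxt))
    if res is None:
        raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()")
    return res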
@@ -497,9 +498,7 @@ class Model:
if vocab_size > len(tokens):
pad_count = vocab_size - len(tokens)
print(
f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]"
)
logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
for i in range(1, pad_count + 1):
tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
scores.append(-1000.0)
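Padding keeps token IDs dense when config.json declares a larger vocab_size than the tokenizer actually provides. A toy sketch:

vocab_size = 8
tokens = [b"hello", b"world"]  # toy vocab, smaller than vocab_size
scores = [0.0, 0.0]

pad_count = vocab_size - len(tokens)
for i in range(1, pad_count + 1):
    tokens.append(f"[PAD{i}]".encode("utf-8"))
    scores.append(-1000.0)  # sentinel score marking the slot as unused

assert len(tokens) == vocab_size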
@@ -599,7 +598,7 @@ class BloomModel(Model):
),
dim=0,
)
print("re-format attention.linear_qkv.weight")
logger.info("re-format attention.linear_qkv.weight")
elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
qkv_bias = data_torch.reshape((n_head, 3, n_embed // n_head))
data_torch = torch.cat(
@@ -610,7 +609,7 @@ class BloomModel(Model):
),
dim=0,
)
print("re-format attention.linear_qkv.bias")
logger.info("re-format attention.linear_qkv.bias")
tensors.append((self.map_tensor_name(name), data_torch))
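Bloom stores query, key and value interleaved per attention head inside query_key_value.weight; the re-format above reshapes the fused matrix so the three projections can be sliced apart and restacked as Q-then-K-then-V. A minimal sketch with toy sizes:

import torch

n_head, n_embed = 4, 32
head_dim = n_embed // n_head
fused = torch.randn(n_head * 3 * head_dim, n_embed)  # Bloom's per-head interleaved layout

qkv = fused.reshape((n_head, 3, head_dim, n_embed))
data = torch.cat(
    (
        qkv[:, 0, :, :].reshape((n_embed, n_embed)),  # all query heads
        qkv[:, 1, :, :].reshape((n_embed, n_embed)),  # all key heads
        qkv[:, 2, :, :].reshape((n_embed, n_embed)),  # all value heads
    ),
    dim=0,
)
assert data.shape == (3 * n_embed, n_embed)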
@@ -688,8 +687,7 @@ class OrionModel(Model):
elif "model_max_length" in self.hparams:
ctx_length = self.hparams["model_max_length"]
else:
print("gguf: can not find ctx length parameter.")
sys.exit()
raise ValueError("gguf: can not find ctx length parameter.")
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_name(self.dir_model.name)
@@ -727,8 +725,7 @@ class BaichuanModel(Model):
elif "model_max_length" in self.hparams:
ctx_length = self.hparams["model_max_length"]
else:
print("gguf: can not find ctx length parameter.")
sys.exit()
raise ValueError("gguf: can not find ctx length parameter.")
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_source_hf_repo(hf_repo)
@@ -754,7 +751,7 @@ class BaichuanModel(Model):
tensors: list[tuple[str, Tensor]] = []
if bid is not None and name == f"model.layers.{bid}.self_attn.W_pack.weight":
print(f"Unpacking and permuting layer {bid}")
logger.info(f"Unpacking and permuting layer {bid}")
tensors = [
(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid),
self._reverse_hf_permute_part(data_torch, 0, head_count, head_count)),
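Baichuan packs Q, K and V into a single W_pack matrix, and Hugging Face additionally permutes rotary Q/K weights; unpacking slices the matrix into thirds and inverts that permutation. A minimal sketch, assuming the standard HF rotary layout (helper names here are illustrative):

import torch

def reverse_hf_permute(w: torch.Tensor, n_head: int) -> torch.Tensor:
    # Inverse of the interleaving Hugging Face applies to rotary Q/K weights.
    d = w.shape[0]
    return w.reshape(n_head, 2, d // n_head // 2, -1).swapaxes(1, 2).reshape(d, -1)

def unpack_w_pack(w_pack: torch.Tensor, n_head: int):
    d = w_pack.shape[0] // 3  # W_pack stacks Q, K, V along dim 0
    q, k, v = w_pack[:d], w_pack[d:2 * d], w_pack[2 * d:]
    return reverse_hf_permute(q, n_head), reverse_hf_permute(k, n_head), v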
@@ -850,8 +847,7 @@ class XverseModel(Model):
elif "model_max_length" in self.hparams:
ctx_length = self.hparams["model_max_length"]
else:
print("gguf: can not find ctx length parameter.")
sys.exit()
raise ValueError("gguf: can not find ctx length parameter.")
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_source_hf_repo(hf_repo)
@@ -1334,7 +1330,7 @@ class DbrxModel(Model):
self.gguf_writer.add_layer_norm_eps(1e-5)
self.gguf_writer.add_file_type(self.ftype)
print(f"gguf: file type = {self.ftype}")
logger.info(f"gguf: file type = {self.ftype}")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
@@ -1604,8 +1600,7 @@ class Phi3MiniModel(Model):
tokenizer_path = self.dir_model / 'tokenizer.model'
if not tokenizer_path.is_file():
print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
sys.exit(1)
raise ValueError(f'Error: Missing {tokenizer_path}')
tokenizer = SentencePieceProcessor()
tokenizer.LoadFromFile(str(tokenizer_path))
@@ -1644,7 +1639,7 @@ class Phi3MiniModel(Model):
for key in added_tokens_json:
token_id = added_tokens_json[key]
if (token_id >= vocab_size):
print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
logger.debug(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
continue
tokens[token_id] = key.encode("utf-8")
@@ -1782,7 +1777,7 @@ class InternLM2Model(Model):
toktypes: list[int] = []
if not tokenizer_path.is_file():
print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
logger.error(f'Error: Missing {tokenizer_path}')
sys.exit(1)
sentencepiece_model = model.ModelProto()
@@ -1801,7 +1796,7 @@ class InternLM2Model(Model):
if text == b"\x00":
# (TODO): fixme
# Hack here and replace the \x00 characters.
print(f"InternLM2 convert token '{text}' to '🐉'!")
logger.warning(f"InternLM2 convert token '{text}' to '🐉'!")
text = "🐉".encode("utf-8")
toktype = SentencePieceTokenTypes.NORMAL
@@ -1842,7 +1837,7 @@ class InternLM2Model(Model):
# TODO: this is a hack, should be fixed
# https://github.com/ggerganov/llama.cpp/pull/6745#issuecomment-2067687048
special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer)
print(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \
logger.warning(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \
in chat mode so that the conversation can end normally.")
special_vocab.add_to_gguf(self.gguf_writer)
@@ -2052,7 +2047,7 @@ class GemmaModel(Model):
# lm_head is not used in llama.cpp, while autoawq will include this tensor in model
# To prevent errors, skip loading lm_head.weight.
if name == "lm_head.weight":
print(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
return []
# ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
@@ -2087,7 +2082,7 @@ class MambaModel(Model):
else:
# Use the GPT-NeoX tokenizer when no tokenizer files are present
tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf"
print(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
neox_reader = gguf.GGUFReader(tokenizer_path, "r")
field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
@@ -2158,13 +2153,13 @@ class MambaModel(Model):
new_name = self.map_tensor_name(name)
if name.endswith(".A_log"):
print("A_log --> A ==> " + new_name)
logger.debug("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)
# assuming token_embd.weight is seen before output.weight
if self._tok_embd is not None and new_name == output_name:
if torch.equal(self._tok_embd, data_torch):
print(f"{output_name} is equivalent to {tok_embd_name}, omitting")
logger.debug(f"{output_name} is equivalent to {tok_embd_name}, omitting")
return []
elif new_name == tok_embd_name:
self._tok_embd = data_torch
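Two Mamba-specific details are visible above: A is stored in log space, so conversion recovers it as -exp(A_log), and checkpoints that tie output.weight to token_embd.weight let the duplicate tensor be omitted. A minimal sketch:

import torch

A_log = torch.randn(16, 32)
A = -torch.exp(A_log)  # every entry of A comes out strictly negative

tok_embd = torch.randn(100, 32)
output = tok_embd  # tied weights: same values under a duplicated name
if torch.equal(tok_embd, output):
    pass  # omit output.weight and reuse token_embd.weight instead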
@@ -2257,6 +2252,7 @@ def parse_args() -> argparse.Namespace:
)
parser.add_argument("--use-temp-file", action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)")
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
return parser.parse_args()
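The new --verbose flag added above feeds straight into the logging configuration in main() below; a minimal sketch of the wiring:

import argparse
import logging

parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
args = parser.parse_args()

# --verbose exposes the logger.debug() diagnostics (chktok, chkhsh, ...);
# without it, output stays at INFO and above.
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)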
@@ -2264,6 +2260,8 @@ def parse_args() -> argparse.Namespace:
def main() -> None:
args = parse_args()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
dir_model = args.model
if args.awq_path:
@@ -2272,15 +2270,15 @@ def main() -> None:
tmp_model_path = args.model / "weighted_model"
dir_model = tmp_model_path
if tmp_model_path.is_dir():
print(f"{tmp_model_path} exists as a weighted model.")
logger.info(f"{tmp_model_path} exists as a weighted model.")
else:
tmp_model_path.mkdir(parents=True, exist_ok=True)
print("Saving new weighted model ...")
logger.info("Saving new weighted model ...")
add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
print(f"Saved weighted model at {tmp_model_path}.")
logger.info(f"Saved weighted model at {tmp_model_path}.")
if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file=sys.stderr)
logger.error(f'Error: {args.model} is not a directory')
sys.exit(1)
ftype_map = {
@@ -2294,7 +2292,7 @@ def main() -> None:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{args.outtype}.gguf'
print(f"Loading model: {dir_model.name}")
logger.info(f"Loading model: {dir_model.name}")
hparams = Model.load_hparams(dir_model)
@@ -2302,20 +2300,20 @@ def main() -> None:
model_class = Model.from_model_architecture(hparams["architectures"][0])
model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian, args.use_temp_file)
print("Set model parameters")
logger.info("Set model parameters")
model_instance.set_gguf_parameters()
print("Set model tokenizer")
logger.info("Set model tokenizer")
model_instance.set_vocab()
if args.vocab_only:
print(f"Exporting model vocab to '{fname_out}'")
logger.info(f"Exporting model vocab to '{fname_out}'")
model_instance.write_vocab()
else:
print(f"Exporting model to '{fname_out}'")
logger.info(f"Exporting model to '{fname_out}'")
model_instance.write()
print(f"Model successfully exported to '{fname_out}'")
logger.info(f"Model successfully exported to '{fname_out}'")
if __name__ == '__main__':