diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index e1ac09e02..b037c50cb 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -11,7 +11,16 @@ import sys from abc import ABC, abstractmethod from enum import IntEnum from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterator, Sequence, TypeVar, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Iterator, + Sequence, + TypeVar, + cast, +) import numpy as np import torch @@ -19,15 +28,15 @@ import torch if TYPE_CHECKING: from torch import Tensor -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +if "NO_LOCAL_GGUF" not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / "gguf-py")) import gguf from convert import LlamaHfVocab, permute - ###### MODEL DEFINITIONS ###### + class SentencePieceTokenTypes(IntEnum): NORMAL = 1 UNKNOWN = 2 @@ -43,18 +52,31 @@ AnyModel = TypeVar("AnyModel", bound="type[Model]") class Model(ABC): _model_classes: dict[str, type[Model]] = {} - def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool): + def __init__( + self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian: bool + ): self.dir_model = dir_model self.ftype = ftype self.fname_out = fname_out self.is_big_endian = is_big_endian - self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + self.endianess = ( + gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + ) self.is_safetensors = self._is_model_safetensors() - self.num_parts = Model.count_model_parts(self.dir_model, ".safetensors" if self.is_safetensors else ".bin") + self.num_parts = Model.count_model_parts( + self.dir_model, ".safetensors" if self.is_safetensors else ".bin" + ) self.part_names = self._get_part_names() self.hparams = Model.load_hparams(self.dir_model) - self.gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False) - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"]) + self.gguf_writer = gguf.GGUFWriter( + fname_out, + gguf.MODEL_ARCH_NAMES[self.model_arch], + endianess=self.endianess, + use_temp_file=False, + ) + self.block_count = self.find_hparam( + ["n_layers", "num_hidden_layers", "n_layer"] + ) @property @abstractmethod @@ -78,20 +100,39 @@ class Model(ABC): ctx: ContextManager[Any] if self.is_safetensors: from safetensors import safe_open - ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu")) + + ctx = cast( + ContextManager[Any], + safe_open(self.dir_model / part_name, framework="pt", device="cpu"), + ) else: - ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) + ctx = contextlib.nullcontext( + torch.load( + str(self.dir_model / part_name), + map_location="cpu", + mmap=True, + weights_only=True, + ) + ) with ctx as model_part: for name in model_part.keys(): - data = model_part.get_tensor(name) if self.is_safetensors else model_part[name] + data = ( + model_part.get_tensor(name) + if self.is_safetensors + else model_part[name] + ) yield name, data def set_gguf_parameters(self): self.gguf_writer.add_name(self.dir_model.name) self.gguf_writer.add_block_count(self.block_count) - if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None: + if ( + n_ctx := self.find_hparam( + 
["max_position_embeddings", "n_ctx"], optional=True + ) + ) is not None: self.gguf_writer.add_context_length(n_ctx) print(f"gguf: context length = {n_ctx}") @@ -99,7 +140,9 @@ class Model(ABC): self.gguf_writer.add_embedding_length(n_embd) print(f"gguf: embedding length = {n_embd}") - if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None: + if ( + n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True) + ) is not None: self.gguf_writer.add_feed_forward_length(n_ff) print(f"gguf: feed forward length = {n_ff}") @@ -117,7 +160,11 @@ class Model(ABC): if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None: self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) print(f"gguf: rms norm epsilon = {f_rms_eps}") - if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: + if ( + f_norm_eps := self.find_hparam( + ["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True + ) + ) is not None: self.gguf_writer.add_layer_norm_eps(f_norm_eps) print(f"gguf: layer norm epsilon = {f_norm_eps}") if (n_experts := self.hparams.get("num_local_experts")) is not None: @@ -131,11 +178,20 @@ class Model(ABC): print(f"gguf: file type = {self.ftype}") def write_tensors(self): - block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + block_count = self.hparams.get( + "n_layers", + self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")), + ) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) for name, data_torch in self.get_tensors(): # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + if name.endswith( + ( + ".attention.masked_bias", + ".attention.bias", + ".attention.rotary_emb.inv_freq", + ) + ): continue old_dtype = data_torch.dtype @@ -160,11 +216,21 @@ class Model(ABC): data = data.astype(np.float32) # TODO: Why cant we use these float16 as-is? 
There should be not reason to store float16 as float32 - if self.ftype == 1 and data_dtype == np.float16 and (n_dims == 1 or new_name.endswith("_norm.weight")): + if ( + self.ftype == 1 + and data_dtype == np.float16 + and (n_dims == 1 or new_name.endswith("_norm.weight")) + ): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and not new_name.endswith("_norm.weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -205,6 +271,7 @@ class Model(ABC): for name in names: cls._model_classes[name] = modelcls return modelcls + return func @classmethod @@ -212,7 +279,7 @@ class Model(ABC): try: return cls._model_classes[arch] except KeyError: - raise NotImplementedError(f'Architecture {arch!r} not supported!') from None + raise NotImplementedError(f"Architecture {arch!r} not supported!") from None def _is_model_safetensors(self) -> bool: return Model.count_model_parts(self.dir_model, ".safetensors") > 0 @@ -221,11 +288,17 @@ class Model(ABC): if self.is_safetensors: if self.num_parts == 1: # there's only one .safetensors file return ("model.safetensors",) - return (f"model-{n:05}-of-{self.num_parts:05}.safetensors" for n in range(1, self.num_parts + 1)) + return ( + f"model-{n:05}-of-{self.num_parts:05}.safetensors" + for n in range(1, self.num_parts + 1) + ) if self.num_parts == 1: # there's only one .bin file return ("pytorch_model.bin",) - return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1)) + return ( + f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" + for n in range(1, self.num_parts + 1) + ) # used for GPT-2 BPE and WordPiece vocabs def get_basic_vocab(self) -> tuple[list[str], list[int]]: @@ -233,11 +306,14 @@ class Model(ABC): toktypes: list[int] = [] from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) assert max(tokenizer.vocab.values()) < vocab_size - reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + reverse_vocab = { + id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items() + } added_vocab = tokenizer.get_added_vocab() for i in range(vocab_size): @@ -272,6 +348,7 @@ class Model(ABC): toktypes: list[int] = [] from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) vocab_size = hparams["vocab_size"] assert max(tokenizer.get_vocab().values()) < vocab_size @@ -285,11 +362,13 @@ class Model(ABC): continue merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) assert len(merged) == 2 - merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) + merges.append(" ".join(map(QwenModel.token_bytes_to_string, merged))) # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined added_vocab = tokenizer.special_tokens - reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in (vocab | added_vocab).items()} + reverse_vocab = { + id_: encoded_tok for encoded_tok, id_ in (vocab | added_vocab).items() + } for i in range(vocab_size): if i not in reverse_vocab: @@ -310,16 +389,22 @@ class Model(ABC): special_vocab.merges = merges # only add special tokens when they were not 
already loaded from config.json if len(special_vocab.special_token_ids) == 0: - special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"]) - special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"]) + special_vocab._set_special_token( + "bos", tokenizer.special_tokens["<|endoftext|>"] + ) + special_vocab._set_special_token( + "eos", tokenizer.special_tokens["<|endoftext|>"] + ) # this one is usually not in config.json anyway - special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"]) + special_vocab._set_special_token( + "unk", tokenizer.special_tokens["<|endoftext|>"] + ) special_vocab.add_to_gguf(self.gguf_writer) def _set_vocab_sentencepiece(self): from sentencepiece import SentencePieceProcessor - tokenizer_path = self.dir_model / 'tokenizer.model' + tokenizer_path = self.dir_model / "tokenizer.model" tokens: list[bytes] = [] scores: list[float] = [] @@ -329,7 +414,7 @@ class Model(ABC): raise FileNotFoundError(f"File not found: {tokenizer_path}") tokenizer = SentencePieceProcessor(str(tokenizer_path)) - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size()) for token_id in range(tokenizer.vocab_size()): piece = tokenizer.id_to_piece(token_id) @@ -350,7 +435,7 @@ class Model(ABC): scores.append(score) toktypes.append(toktype) - added_tokens_file = self.dir_model / 'added_tokens.json' + added_tokens_file = self.dir_model / "added_tokens.json" if added_tokens_file.is_file(): with open(added_tokens_file, "r", encoding="utf-8") as f: added_tokens_json = json.load(f) @@ -407,10 +492,15 @@ class GPTNeoXModel(Model): self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_dimension_count( - int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])), + int( + self.hparams["rotary_pct"] + * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + ), ) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) - self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True)) + self.gguf_writer.add_parallel_residual( + self.hparams.get("use_parallel_residual", True) + ) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) @@ -440,10 +530,13 @@ class BloomModel(Model): n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) for name, data_torch in tensors.items(): - if "lm_head.weight" not in tensors.keys() and "output.weight" not in tensors.keys(): + if ( + "lm_head.weight" not in tensors.keys() + and "output.weight" not in tensors.keys() + ): has_lm_head = False - name = re.sub(r'transformer\.', '', name) + name = re.sub(r"transformer\.", "", name) old_dtype = data_torch.dtype @@ -497,7 +590,12 @@ class BloomModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"=> {new_name}, shape = {data.shape}, {old_dtype} --> {data.dtype}") @@ -506,7 +604,10 @@ class BloomModel(Model): if not has_lm_head and name == "word_embeddings.weight": self.gguf_writer.add_tensor("output.weight", data) - print(name, f"=> output.weight, shape = 
{data.shape}, {old_dtype} --> {data.dtype}") + print( + name, + f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}", + ) @Model.register("MPTForCausalLM") @@ -538,16 +639,26 @@ class MPTModel(Model): if self.hparams["attn_config"]["clip_qkv"] is not None: self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"]) if self.hparams["attn_config"]["alibi"]: - self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"]) + self.gguf_writer.add_max_alibi_bias( + self.hparams["attn_config"]["alibi_bias_max"] + ) else: self.gguf_writer.add_max_alibi_bias(0.0) def write_tensors(self): - block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers")) + block_count = self.hparams.get( + "n_layers", self.hparams.get("num_hidden_layers") + ) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) for name, data_torch in self.get_tensors(): # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + if name.endswith( + ( + ".attention.masked_bias", + ".attention.bias", + ".attention.rotary_emb.inv_freq", + ) + ): continue old_dtype = data_torch.dtype @@ -560,7 +671,9 @@ class MPTModel(Model): # map tensor names if "scales" in name: - new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales")) + new_name = tensor_map.get_name( + name, try_suffixes=(".weight", ".bias", ".scales") + ) if new_name is not None: new_name = new_name.replace("scales", "act.scales") else: @@ -581,7 +694,12 @@ class MPTModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -664,10 +782,17 @@ class OrionModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) - print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + print( + f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}" + ) self.gguf_writer.add_tensor(new_name, data) @@ -702,15 +827,22 @@ class BaichuanModel(Model): self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) - self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_rope_dimension_count( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count_kv(head_count_kv) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if ( + self.hparams.get("rope_scaling") is not None + and "factor" in self.hparams["rope_scaling"] + ): if self.hparams["rope_scaling"].get("type") == "linear": self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) 
- self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + self.gguf_writer.add_rope_scaling_factor( + self.hparams["rope_scaling"]["factor"] + ) def write_tensors(self): # Collect tensors from generator object @@ -721,14 +853,19 @@ class BaichuanModel(Model): head_count_kv = self.hparams.get("num_key_value_heads", head_count) for i in range(block_count): - if (w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight")) is not None: + if ( + w := model_kv.get(f"model.layers.{i}.self_attn.W_pack.weight") + ) is not None: print(f"Unpacking and permuting layer {i}") - model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = \ - self._reverse_hf_permute_part(w, 0, head_count, head_count) - model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = \ - self._reverse_hf_permute_part(w, 1, head_count, head_count_kv) - model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = \ - self._reverse_hf_part(w, 2) + model_kv[ + f"model.layers.{i}.self_attn.q_proj.weight" + ] = self._reverse_hf_permute_part(w, 0, head_count, head_count) + model_kv[ + f"model.layers.{i}.self_attn.k_proj.weight" + ] = self._reverse_hf_permute_part(w, 1, head_count, head_count_kv) + model_kv[ + f"model.layers.{i}.self_attn.v_proj.weight" + ] = self._reverse_hf_part(w, 2) del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"] for name, data_torch in model_kv.items(): @@ -762,31 +899,48 @@ class BaichuanModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) - print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + print( + f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}" + ) self.gguf_writer.add_tensor(new_name, data) - def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: + def _reverse_hf_permute( + self, weights: Tensor, n_head: int, n_kv_head: int | None = None + ) -> Tensor: if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head return ( - weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) .swapaxes(1, 2) .reshape(weights.shape) ) def _reverse_hf_permute_part( - self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None, + self, + weights: Tensor, + n_part: int, + n_head: int, + n_head_kv: int | None = None, ) -> Tensor: r = weights.shape[0] // 3 - return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv) + return self._reverse_hf_permute( + weights[r * n_part : r * n_part + r, ...], n_head, n_head_kv + ) def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor: r = weights.shape[0] // 3 - return weights[r * n_part:r * n_part + r, ...] + return weights[r * n_part : r * n_part + r, ...] 
@Model.register("XverseForCausalLM") @@ -802,20 +956,23 @@ class XverseModel(Model): toktypes: list[int] = [] from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model) vocab_size = hparams.get("vocab_size", len(tokenizer.vocab)) assert max(tokenizer.vocab.values()) < vocab_size - reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + reverse_vocab = { + id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items() + } added_vocab = tokenizer.get_added_vocab() for token_id in range(vocab_size): - token_text = reverse_vocab[token_id].encode('utf-8') + token_text = reverse_vocab[token_id].encode("utf-8") # replace "\x00" to string with length > 0 if token_text == b"\x00": toktype = gguf.TokenType.BYTE # special - token_text = f"<{token_text}>".encode('utf-8') - elif re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text): + token_text = f"<{token_text}>".encode("utf-8") + elif re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text): toktype = gguf.TokenType.BYTE # special elif reverse_vocab[token_id] in added_vocab: if tokenizer.added_tokens_decoder[token_id].special: @@ -859,15 +1016,22 @@ class XverseModel(Model): self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) - self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_rope_dimension_count( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count_kv(head_count_kv) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: + if ( + self.hparams.get("rope_scaling") is not None + and "factor" in self.hparams["rope_scaling"] + ): if self.hparams["rope_scaling"].get("type") == "linear": self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + self.gguf_writer.add_rope_scaling_factor( + self.hparams["rope_scaling"]["factor"] + ) def write_tensors(self): # Collect tensors from generator object @@ -890,9 +1054,13 @@ class XverseModel(Model): # HF models permute some of the tensors, so we need to undo that if name.endswith(("q_proj.weight")): - data_torch = self._reverse_hf_permute(data_torch, head_count, head_count) + data_torch = self._reverse_hf_permute( + data_torch, head_count, head_count + ) if name.endswith(("k_proj.weight")): - data_torch = self._reverse_hf_permute(data_torch, head_count, head_count_kv) + data_torch = self._reverse_hf_permute( + data_torch, head_count, head_count_kv + ) data = data_torch.squeeze().numpy() @@ -914,18 +1082,29 @@ class XverseModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) - print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + print( + f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}" + ) self.gguf_writer.add_tensor(new_name, data) - def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: 
int | None = None) -> Tensor: + def _reverse_hf_permute( + self, weights: Tensor, n_head: int, n_kv_head: int | None = None + ) -> Tensor: if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head return ( - weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) .swapaxes(1, 2) .reshape(weights.shape) ) @@ -993,7 +1172,9 @@ class FalconModel(Model): # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py if "query_key_value" in name: - qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head) + qkv = data_torch.view( + n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head + ) q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head) k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head) v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head) @@ -1019,7 +1200,12 @@ class FalconModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -1086,13 +1272,19 @@ class RefactModel(Model): tensors = dict(self.get_tensors()) for i in range(block_count): if (w := tensors.get(f"transformer.h.{i}.attn.kv.weight")) is not None: - tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[:n_head_kv * head_dim] - tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[n_head_kv * head_dim:] + tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[ + : n_head_kv * head_dim + ] + tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[ + n_head_kv * head_dim : + ] del tensors[f"transformer.h.{i}.attn.kv.weight"] if (w := tensors.get(f"transformer.h.{i}.attn.q.weight")) is not None: tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w del tensors[f"transformer.h.{i}.attn.q.weight"] - if (w := tensors.get(f"transformer.h.{i}.mlp.gate_up_proj.weight")) is not None: + if ( + w := tensors.get(f"transformer.h.{i}.mlp.gate_up_proj.weight") + ) is not None: tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim] tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:] del tensors[f"transformer.h.{i}.mlp.gate_up_proj.weight"] @@ -1124,7 +1316,12 @@ class RefactModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -1137,12 +1334,14 @@ class PersimmonModel(Model): model_arch = gguf.MODEL_ARCH.PERSIMMON def set_gguf_parameters(self): - block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")) + block_count = self.hparams.get( + "num_layers", self.hparams.get("num_hidden_layers") + ) head_count = self.hparams["num_attention_heads"] head_count_kv = head_count hidden_size = self.hparams["hidden_size"] - self.gguf_writer.add_name('persimmon-8b-chat') + self.gguf_writer.add_name("persimmon-8b-chat") 
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(hidden_size) self.gguf_writer.add_block_count(block_count) @@ -1165,7 +1364,9 @@ class PersimmonModel(Model): # self.gguf_writer.add_eos_token_id(71013) def write_tensors(self): - block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")) + block_count = self.hparams.get( + "num_layers", self.hparams.get("num_hidden_layers") + ) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) for name, data_torch in self.get_tensors(): @@ -1183,18 +1384,35 @@ class PersimmonModel(Model): self.gguf_writer.add_tensor(new_name, data) -@Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM") -class StableLMModel(Model): +@Model.register( + "StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM" +) +class StableLM2Model(Model): model_arch = gguf.MODEL_ARCH.STABLELM + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + model_arch = ( + gguf.MODEL_ARCH.STABLELM + if self.hparams["num_hidden_layers"] < 40 + else gguf.MODEL_ARCH.STABLELM2 + ) + self.gguf_writer = gguf.GGUFWriter( + self.fname_out, + gguf.MODEL_ARCH_NAMES[model_arch], + endianess=self.endianess, + use_temp_file=False, + ) + def set_vocab(self): if (self.dir_model / "tokenizer.json").is_file(): self._set_vocab_gpt2() else: - # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab + # StableLM 2 uses a vocab in a similar format to Qwen's vocab self._set_vocab_qwen() def set_gguf_parameters(self): + super().set_gguf_parameters() hparams = self.hparams block_count = hparams["num_hidden_layers"] @@ -1204,10 +1422,22 @@ class StableLMModel(Model): self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"]) - self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) + self.gguf_writer.add_rope_dimension_count( + int( + rotary_factor + * (hparams["hidden_size"] // hparams["num_attention_heads"]) + ) + ) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) - self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True) - self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"])) + self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"]) + self.gguf_writer.add_parallel_residual( + hparams["use_parallel_residual"] + if "use_parallel_residual" in hparams + else True + ) + self.gguf_writer.add_layer_norm_eps( + self.find_hparam(["layer_norm_eps", "norm_eps"]) + ) @Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM") @@ -1216,7 +1446,7 @@ class LlamaModel(Model): def set_vocab(self): try: - self. 
_set_vocab_sentencepiece() + self._set_vocab_sentencepiece() except FileNotFoundError: self._set_vocab_llama_hf() @@ -1224,11 +1454,16 @@ class LlamaModel(Model): super().set_gguf_parameters() hparams = self.hparams self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) + self.gguf_writer.add_rope_dimension_count( + hparams["hidden_size"] // hparams["num_attention_heads"] + ) # Same as super class, but permuting q_proj, k_proj def write_tensors(self): - block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + block_count = self.hparams.get( + "n_layers", + self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")), + ) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) n_head = self.hparams.get("num_attention_heads") n_kv_head = self.hparams.get("num_key_value_heads") @@ -1236,7 +1471,13 @@ class LlamaModel(Model): experts = dict() for name, data_torch in self.get_tensors(): # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + if name.endswith( + ( + ".attention.masked_bias", + ".attention.bias", + ".attention.rotary_emb.inv_freq", + ) + ): continue old_dtype = data_torch.dtype @@ -1285,14 +1526,20 @@ class LlamaModel(Model): if self.ftype == 1 and data_dtype == np.float32: data = data.astype(np.float16) - merged_name = f"layers.{bid}.feed_forward.experts.w{wid}.weight" + merged_name = ( + f"layers.{bid}.feed_forward.experts.w{wid}.weight" + ) - new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias")) + new_name = tensor_map.get_name( + merged_name, try_suffixes=(".weight", ".bias") + ) if new_name is None: print(f"Can not map tensor {name!r}") sys.exit() - print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}") + print( + f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}" + ) self.gguf_writer.add_tensor(new_name, data) continue @@ -1315,7 +1562,12 @@ class LlamaModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -1341,13 +1593,22 @@ class GrokModel(Model): self.gguf_writer.add_name("Grok") def write_tensors(self): - block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + block_count = self.hparams.get( + "n_layers", + self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")), + ) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) n_experts = self.hparams.get("num_local_experts") experts = dict() for name, data_torch in self.get_tensors(): # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + if name.endswith( + ( + ".attention.masked_bias", + ".attention.bias", + ".attention.rotary_emb.inv_freq", + ) + ): continue old_dtype = data_torch.dtype @@ -1389,14 +1650,20 @@ class GrokModel(Model): if self.ftype == 1 and data_dtype == np.float32: data = data.astype(np.float16) - merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight" + 
merged_name = ( + f"transformer.decoder_layer.{bid}.moe.{wid}.weight" + ) - new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias")) + new_name = tensor_map.get_name( + merged_name, try_suffixes=(".weight", ".bias") + ) if new_name is None: print(f"Can not map tensor {name!r}") sys.exit() - print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}") + print( + f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}" + ) self.gguf_writer.add_tensor(new_name, data) continue @@ -1419,7 +1686,12 @@ class GrokModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -1534,7 +1806,9 @@ class MiniCPMModel(Model): self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) - self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_rope_dimension_count( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) @@ -1543,24 +1817,37 @@ class MiniCPMModel(Model): def set_vocab(self): self._set_vocab_llama_hf() - def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: + def _reverse_hf_permute( + self, weights: Tensor, n_head: int, n_kv_head: int | None = None + ) -> Tensor: if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head return ( - weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) .swapaxes(1, 2) .reshape(weights.shape) ) def write_tensors(self): - block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + block_count = self.hparams.get( + "n_layers", + self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")), + ) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) n_head = self.hparams.get("num_attention_heads") n_kv_head = self.hparams.get("num_key_value_heads") for name, data_torch in self.get_tensors(): # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + if name.endswith( + ( + ".attention.masked_bias", + ".attention.bias", + ".attention.rotary_emb.inv_freq", + ) + ): continue old_dtype = data_torch.dtype @@ -1595,7 +1882,12 @@ class MiniCPMModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -1610,11 +1902,14 @@ class QwenModel(Model): 
@staticmethod def token_bytes_to_string(b): from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + byte_encoder = bytes_to_unicode() - return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')]) + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) @staticmethod - def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]: + def bpe( + mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None + ) -> list[bytes]: parts = [bytes([b]) for b in token] while True: min_idx = None @@ -1627,7 +1922,11 @@ class QwenModel(Model): if min_rank is None or (max_rank is not None and min_rank >= max_rank): break assert min_idx is not None - parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:] + parts = ( + parts[:min_idx] + + [parts[min_idx] + parts[min_idx + 1]] + + parts[min_idx + 2 :] + ) return parts def set_vocab(self): @@ -1640,7 +1939,9 @@ class QwenModel(Model): self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) - self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_rope_dimension_count( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) @@ -1679,14 +1980,19 @@ class QwenModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") self.gguf_writer.add_tensor(new_name, data) @Model.register("Qwen2ForCausalLM") class Qwen2Model(Model): model_arch = gguf.MODEL_ARCH.QWEN2 @@ -1706,15 +2012,28 @@ class GPT2Model(Model): self.gguf_writer.add_file_type(self.ftype) def write_tensors(self): - block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + block_count = self.hparams.get( + "n_layers", + self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")), + ) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) for name, data_torch in self.get_tensors(): # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq", ".attn.bias", ".attn.masked_bias")): + if name.endswith( + ( + ".attention.masked_bias", + ".attention.bias", + ".attention.rotary_emb.inv_freq", + ".attn.bias", + ".attn.masked_bias", + ) + ): continue - if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight")): + if name.endswith( + (".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight") + ): data_torch = data_torch.transpose(1, 0) old_dtype = data_torch.dtype @@ -1743,7 +2062,12 @@ class GPT2Model(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1
and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -1768,14 +2092,18 @@ class Phi2Model(Model): n_head = self.find_hparam(["num_attention_heads", "n_head"]) self.gguf_writer.add_name("Phi2") - self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"])) + self.gguf_writer.add_context_length( + self.find_hparam(["n_positions", "max_position_embeddings"]) + ) self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_feed_forward_length(4 * n_embd) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head) - self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"])) + self.gguf_writer.add_layer_norm_eps( + self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]) + ) self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_add_bos_token(False) @@ -1798,7 +2126,9 @@ class PlamoModel(Model): self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) - self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong + self.gguf_writer.add_head_count_kv( + 5 + ) # hparams["num_key_value_heads"]) is wrong self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) def shuffle_attn_q_weight(self, data_torch): @@ -1816,7 +2146,9 @@ class PlamoModel(Model): return data_torch def write_tensors(self): - block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers")) + block_count = self.hparams.get( + "num_layers", self.hparams.get("num_hidden_layers") + ) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) for name, data_torch in self.get_tensors(): @@ -1855,7 +2187,12 @@ class PlamoModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -1884,10 +2221,15 @@ class CodeShellModel(Model): self.gguf_writer.add_rope_scaling_factor(1.0) def write_tensors(self): - block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + block_count = self.hparams.get( + "n_layers", + self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")), + ) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) tensors = dict(self.get_tensors()) - has_lm_head = "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys() + has_lm_head = ( + "lm_head.weight" in tensors.keys() or "output.weight" in tensors.keys() + ) for name, data_torch in tensors.items(): # we don't need these if name.endswith((".attn.rotary_emb.inv_freq")): @@ -1919,7 +2261,12 @@ class CodeShellModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and 
name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -1928,7 +2275,10 @@ class CodeShellModel(Model): if not has_lm_head and name == "transformer.wte.weight": self.gguf_writer.add_tensor("output.weight", data) - print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}") + print( + name, + f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}", + ) @Model.register("InternLM2ForCausalLM") @@ -1943,14 +2293,14 @@ class InternLM2Model(Model): from sentencepiece import SentencePieceProcessor from sentencepiece import sentencepiece_model_pb2 as model - tokenizer_path = self.dir_model / 'tokenizer.model' + tokenizer_path = self.dir_model / "tokenizer.model" tokens: list[bytes] = [] scores: list[float] = [] toktypes: list[int] = [] if not tokenizer_path.is_file(): - print(f'Error: Missing {tokenizer_path}', file=sys.stderr) + print(f"Error: Missing {tokenizer_path}", file=sys.stderr) sys.exit(1) sentencepiece_model = model.ModelProto() @@ -1958,7 +2308,7 @@ class InternLM2Model(Model): add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix tokenizer = SentencePieceProcessor(str(tokenizer_path)) - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size()) for token_id in range(vocab_size): piece = tokenizer.id_to_piece(token_id) @@ -1984,7 +2334,7 @@ class InternLM2Model(Model): scores.append(score) toktypes.append(toktype) - added_tokens_file = self.dir_model / 'added_tokens.json' + added_tokens_file = self.dir_model / "added_tokens.json" if added_tokens_file.is_file(): with open(added_tokens_file, "r", encoding="utf-8") as f: added_tokens_json = json.load(f) @@ -2005,14 +2355,16 @@ class InternLM2Model(Model): if "chat" in os.path.basename(self.dir_model.absolute()): # For the chat model, we replace the eos with '<|im_end|>'. special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer) - print(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \ -in chat mode so that the conversation can end normally.") + print( + f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \ +in chat mode so that the conversation can end normally." 
+ ) special_vocab.add_to_gguf(self.gguf_writer) def _try_get_sft_eos(self, tokenizer): - unused_145_list = tokenizer.encode('[UNUSED_TOKEN_145]') - im_end_list = tokenizer.encode('<|im_end|>') + unused_145_list = tokenizer.encode("[UNUSED_TOKEN_145]") + im_end_list = tokenizer.encode("<|im_end|>") assert (len(unused_145_list) == 1) ^ (len(im_end_list) == 1) if len(unused_145_list) == 1: eos_token = unused_145_list[0] @@ -2023,9 +2375,13 @@ in chat mode so that the conversation can end normally.") def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int): if n_head_kv is not None and n_head != n_head_kv: n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) + return ( + weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) + .swapaxes(1, 2) + .reshape(weights.shape) + ) def set_gguf_parameters(self): self.gguf_writer.add_name("InternLM2") @@ -2065,7 +2421,12 @@ in chat mode so that the conversation can end normally.") data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -2093,15 +2454,35 @@ in chat mode so that the conversation can end normally.") if re.match(qkv_pattern, name): bid = re.findall(qkv_pattern, name)[0] qkv = data_torch - qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim) - q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :] + qkv = rearrange( + qkv.T, + " o (g n i) ->o g n i", + g=num_groups, + n=q_per_kv + 2, + i=head_dim, + ) + q, k, v = ( + qkv[..., :q_per_kv, :], + qkv[..., q_per_kv : q_per_kv + 1, :], + qkv[..., q_per_kv + 1 : q_per_kv + 2, :], + ) # The model weights of q and k require additional reshape.
- q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads) - k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads) + q = self._hf_permute_qk( + rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads + ) + k = self._hf_permute_qk( + rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads + ) v = rearrange(v, " o g n i -> o (g n i)").T - self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q) - self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k) - self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wv.weight", v) + self.post_write_tensors( + tensor_map, f"model.layers.{bid}.attention.wq.weight", q + ) + self.post_write_tensors( + tensor_map, f"model.layers.{bid}.attention.wk.weight", k + ) + self.post_write_tensors( + tensor_map, f"model.layers.{bid}.attention.wv.weight", v + ) else: self.post_write_tensors(tensor_map, name, data_torch) @@ -2131,7 +2512,9 @@ class BertModel(Model): # get pooling type if pooling_path is not None: - with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f: + with open( + self.dir_model / pooling_path / "config.json", encoding="utf-8" + ) as f: pooling = json.load(f) if pooling["pooling_mode_mean_tokens"]: pooling_type = gguf.PoolingType.MEAN @@ -2156,6 +2539,7 @@ class BertModel(Model): if tok.startswith("##"): return tok[2:] return "\u2581" + tok + tokens = list(map(phantom, tokens)) # add vocab to gguf @@ -2172,7 +2556,11 @@ class BertModel(Model): tensors = dict(self.get_tensors()) for name, data_torch in tensors.items(): # we are only using BERT for embeddings so we don't need the pooling layer - if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"): + if name in ( + "embeddings.position_ids", + "pooler.dense.weight", + "pooler.dense.bias", + ): continue # we don't need these # map tensor names @@ -2186,8 +2574,11 @@ class BertModel(Model): new_dtype: type[np.floating[Any]] if ( - self.ftype == 1 and name.endswith(".weight") and n_dims == 2 - and name != "embeddings.token_type_embeddings.weight" # not used with get_rows, must be F32 + self.ftype == 1 + and name.endswith(".weight") + and n_dims == 2 + and name + != "embeddings.token_type_embeddings.weight" # not used with get_rows, must be F32 ): # if f16 desired, convert any float32 2-dim weight tensors to float16 new_dtype = np.float16 @@ -2250,14 +2641,21 @@ class GemmaModel(Model): self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) - self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv( + self.hparams["num_key_value_heads"] + if "num_key_value_heads" in hparams + else hparams["num_attention_heads"] + ) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) self.gguf_writer.add_key_length(hparams["head_dim"]) self.gguf_writer.add_value_length(hparams["head_dim"]) self.gguf_writer.add_file_type(self.ftype) def write_tensors(self): - block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + block_count = self.hparams.get( + "n_layers", + self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")), + ) tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) for 
name, data_torch in self.get_tensors(): @@ -2284,7 +2682,12 @@ class GemmaModel(Model): data = data.astype(np.float32) # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + if ( + self.ftype == 1 + and data_dtype == np.float32 + and name.endswith(".weight") + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -2315,17 +2718,25 @@ class MambaModel(Model): else: # Use the GPT-NeoX tokenizer when no tokenizer files are present tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf" - print(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'") + print( + f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'" + ) neox_reader = gguf.GGUFReader(tokenizer_path, "r") field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL) self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1])) field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST) - self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size]) + self.gguf_writer.add_token_list( + [bytes(field.parts[i]) for i in field.data][:vocab_size] + ) field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE) - self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size]) + self.gguf_writer.add_token_types( + [field.parts[i].tolist()[0] for i in field.data][:vocab_size] + ) field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES) - self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data]) + self.gguf_writer.add_token_merges( + [bytes(field.parts[i]) for i in field.data] + ) field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID) self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0]) field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID) @@ -2335,23 +2746,37 @@ class MambaModel(Model): def set_gguf_parameters(self): d_model = self.find_hparam(["hidden_size", "d_model"]) - d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 - d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_inner = ( + self.find_hparam(["intermediate_size", "d_inner"], optional=True) + or 2 * d_model + ) d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16 # ceiling division # ref: https://stackoverflow.com/a/17511341/22827863 # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 - dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16) - rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -( + d_model // -16 + ) + rms_norm_eps = ( + self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) + or 1e-5 + ) # Fail early for models which don't have a block expansion factor of 2 assert d_inner == 2 * d_model self.gguf_writer.add_name(self.dir_model.name) - self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default + self.gguf_writer.add_context_length( + 2**20 + ) # arbitrary value; for those who use the default self.gguf_writer.add_embedding_length(d_model) - self.gguf_writer.add_feed_forward_length(0) # unused, but 
seemingly required when loading - self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading + self.gguf_writer.add_feed_forward_length( + 0 + ) # unused, but seemingly required when loading + self.gguf_writer.add_head_count( + 0 + ) # unused, but seemingly required when loading self.gguf_writer.add_block_count(self.hparams["n_layer"]) self.gguf_writer.add_ssm_conv_kernel(d_conv) self.gguf_writer.add_ssm_inner_size(d_inner) @@ -2366,7 +2791,7 @@ class MambaModel(Model): tok_embd = None tok_embd_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD] + ".weight" - output_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT] + ".weight" + output_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT] + ".weight" for name, data_torch in self.get_tensors(): old_dtype = data_torch.dtype @@ -2407,8 +2832,17 @@ class MambaModel(Model): data = data.astype(np.float32) # if f16 desired, convert big float32 2-dim weight tensors to float16 - new_weight_name = new_name[:-len(".weight")] if new_name.endswith(".weight") else "" - if self.ftype == 1 and data_dtype == np.float32 and new_weight_name.endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2: + new_weight_name = ( + new_name[: -len(".weight")] if new_name.endswith(".weight") else "" + ) + if ( + self.ftype == 1 + and data_dtype == np.float32 + and new_weight_name.endswith( + (".ssm_in", ".ssm_out", "token_embd", "output") + ) + and n_dims == 2 + ): data = data.astype(np.float16) print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") @@ -2438,25 +2872,36 @@ class CommandR2Model(Model): def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Convert a huggingface model to a GGML compatible file") + description="Convert a huggingface model to a GGML compatible file" + ) parser.add_argument( - "--vocab-only", action="store_true", + "--vocab-only", + action="store_true", help="extract only the vocab", ) parser.add_argument( - "--awq-path", type=Path, default=None, - help="Path to scale awq cache file") + "--awq-path", type=Path, default=None, help="Path to scale awq cache file" + ) parser.add_argument( - "--outfile", type=Path, + "--outfile", + type=Path, help="path to write to; default: based on input", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16"], default="f16", + "--outtype", + type=str, + choices=["f32", "f16"], + default="f16", help="output format - use f32 for float32, f16 for float16", ) - parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") parser.add_argument( - "model", type=Path, + "--bigendian", + action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "model", + type=Path, help="directory containing model file", ) @@ -2469,8 +2914,9 @@ def main() -> None: dir_model = args.model if args.awq_path: - sys.path.insert(1, str(Path(__file__).parent / 'awq-py')) + sys.path.insert(1, str(Path(__file__).parent / "awq-py")) from awq.apply_awq import add_scale_weights # type: ignore[import-not-found] + tmp_model_path = args.model / "weighted_model" dir_model = tmp_model_path if tmp_model_path.is_dir(): @@ -2482,19 +2928,18 @@ def main() -> None: print(f"Saved weighted model at {tmp_model_path}.") if not dir_model.is_dir(): - print(f'Error: {args.model} is not a directory', file=sys.stderr) + print(f"Error: {args.model} is not a directory", file=sys.stderr) sys.exit(1) ftype_map = { "f32": gguf.GGMLQuantizationType.F32, "f16": 
gguf.GGMLQuantizationType.F16, } - if args.outfile is not None: fname_out = args.outfile else: # output in the same directory as the model by default - fname_out = dir_model / f'ggml-model-{args.outtype}.gguf' + fname_out = dir_model / f"ggml-model-{args.outtype}.gguf" print(f"Loading model: {dir_model.name}") @@ -2502,7 +2947,9 @@ def main() -> None: with torch.inference_mode(): model_class = Model.from_model_architecture(hparams["architectures"][0]) - model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian) + model_instance = model_class( + dir_model, ftype_map[args.outtype], fname_out, args.bigendian + ) print("Set model parameters") model_instance.set_gguf_parameters() @@ -2520,5 +2967,5 @@ def main() -> None: print(f"Model successfully exported to '{fname_out}'") -if __name__ == '__main__': +if __name__ == "__main__": main()
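
Note (reviewer sketch, not part of the patch): the Model.register / Model.from_model_architecture pair that black re-wraps above is a plain decorator registry keyed by the HF architecture strings found in config.json. A minimal self-contained sketch of the same pattern, with the classmethod simplified to a module-level function:

_model_classes: dict[str, type] = {}

def register(*names: str):
    # Class decorator: map each HF architecture name to a converter class.
    def func(modelcls):
        for name in names:
            _model_classes[name] = modelcls
        return modelcls
    return func

@register("LlamaForCausalLM", "MistralForCausalLM")
class LlamaModel:
    pass

# Dispatch works the way from_model_architecture does: a dict lookup
# on hparams["architectures"][0], raising if the name is unknown.
assert _model_classes["MistralForCausalLM"] is LlamaModel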
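Note (reviewer sketch, not part of the patch): one non-mechanical change is buried in Model.write_tensors — the f32 -> f16 conversion gains an `and not new_name.endswith("_norm.weight")` guard, so norm weights now stay float32 even when 2-dimensional. A small predicate mirroring the patched condition, useful for eyeballing the rule in isolation; the tensor names in the asserts are illustrative only:

import numpy as np

def should_convert_to_f16(ftype, data_dtype, name, new_name, n_dims):
    # ftype == 1 means f16 output was requested on the command line.
    return (
        ftype == 1
        and data_dtype == np.float32
        and name.endswith(".weight")
        and not new_name.endswith("_norm.weight")
        and n_dims == 2
    )

# An ordinary 2-D projection weight is converted to f16 ...
assert should_convert_to_f16(1, np.float32, "model.layers.0.mlp.up_proj.weight", "blk.0.ffn_up.weight", 2)
# ... while a norm weight is kept in float32 even if it is 2-D.
assert not should_convert_to_f16(1, np.float32, "model.norm.weight", "output_norm.weight", 2)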
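Note (reviewer sketch, not part of the patch): the _reverse_hf_permute helpers re-wrapped in BaichuanModel, XverseModel, and MiniCPMModel all undo the rotary q/k row interleaving that HF checkpoints apply. A self-contained round-trip check, assuming the forward permutation pairs each head's rows as (head_dim // 2, 2); hf_permute here is a hypothetical helper written only for this test:

import torch

def reverse_hf_permute(weights, n_head, n_kv_head=None):
    # Same body as the patched methods: de-interleave the two rotary
    # halves of every attention head's rows.
    if n_kv_head is not None and n_head != n_kv_head:
        n_head //= n_kv_head
    return (
        weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
        .swapaxes(1, 2)
        .reshape(weights.shape)
    )

def hf_permute(weights, n_head):
    # Assumed forward direction: the exact structural inverse of the above.
    return (
        weights.reshape(n_head, weights.shape[0] // n_head // 2, 2, *weights.shape[1:])
        .swapaxes(1, 2)
        .reshape(weights.shape)
    )

# Round trip with 4 heads of head_dim 8 over width-6 columns: input recovered.
w = torch.arange(32 * 6, dtype=torch.float32).reshape(32, 6)
assert torch.equal(reverse_hf_permute(hf_permute(w, 4), 4), w)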