diff --git a/model.py b/model.py
index 8be445cb8..9291666c5 100644
--- a/model.py
+++ b/model.py
@@ -8,7 +8,9 @@
 import contextlib
 import numpy as np
 from pathlib import Path
+from typing import TypeAlias, Any
 
+NDArray: TypeAlias = 'np.ndarray[Any, Any]'
 
 class Model:
     def __init__(self, dir_model: Path, ftype: int, fname_out: Path):
@@ -43,6 +45,8 @@ class Model:
             return gguf.MODEL_ARCH.BLOOM
         if arch == "MPTForCausalLM":
             return gguf.MODEL_ARCH.MPT
+        if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
+            return gguf.MODEL_ARCH.BAICHUAN
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
     def set_vocab(self):
@@ -174,6 +178,8 @@ class Model:
             return BloomModel
         if model_architecture == "MPTForCausalLM":
             return MPTModel
+        if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
+            return BaichuanModel
         return Model
 
 class StableLMModel(Model):
@@ -288,3 +294,180 @@ class MPTModel(Model):
         if self.hparams["attn_config"]["clip_qkv"] is not None:
             self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
         self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])
+
+
+class BaichuanModel(Model):
+    def set_vocab(self):
+        from sentencepiece import SentencePieceProcessor  # type: ignore[import]
+        tokens: list[bytes] = []
+        scores: list[float] = []
+        toktypes: list[int] = []
+
+        tokenizer_model_file = self.dir_model / 'tokenizer.model'
+        if not tokenizer_model_file.is_file():
+            print(f'Error: Missing {tokenizer_model_file}', file=sys.stderr)
+            sys.exit(1)
+
+        # vocab type sentencepiece
+        print("gguf: get sentencepiece tokenizer vocab, scores and token types")
+
+        tokenizer = SentencePieceProcessor(str(tokenizer_model_file))
+        vocab_size = self.hparams.get('vocab_size')
+        if vocab_size is None:
+            vocab_size = tokenizer.vocab_size()
+
+        for i in range(vocab_size):
+            piece = tokenizer.id_to_piece(i)
+            text = piece.encode("utf-8")
+            score = tokenizer.get_score(i)
+
+            toktype = 1  # default to normal token type
+            if tokenizer.is_unknown(i):
+                toktype = 2
+            if tokenizer.is_control(i):
+                toktype = 3
+
+            # toktype = 4 is user-defined = tokens from added_tokens.json
+
+            if tokenizer.is_unused(i):
+                toktype = 5
+            if tokenizer.is_byte(i):
+                toktype = 6
+
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        added_tokens_file = self.dir_model / 'added_tokens.json'
+        if added_tokens_file.is_file():
+            with open(added_tokens_file, "r", encoding="utf-8") as f:
+                addtokens_json = json.load(f)
+
+            print("gguf: get added tokens")
+
+            for key in addtokens_json:
+                tokens.append(key.encode("utf-8"))
+                scores.append(-1000.0)
+                toktypes.append(4)  # user-defined token type
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        head_count = self.hparams["num_attention_heads"]
+
+        if "num_key_value_heads" in self.hparams:
+            head_count_kv = self.hparams["num_key_value_heads"]
+        else:
+            head_count_kv = head_count
+
+        if "_name_or_path" in self.hparams:
+            hf_repo = self.hparams["_name_or_path"]
+        else:
+            hf_repo = ""
+
+        ctx_length = 0
+        if "max_sequence_length" in self.hparams:
+            ctx_length = self.hparams["max_sequence_length"]
+        elif "max_position_embeddings" in self.hparams:
+            ctx_length = self.hparams["max_position_embeddings"]
+        elif "model_max_length" in self.hparams:
+            ctx_length = self.hparams["model_max_length"]
+        else:
+            print("gguf: cannot find ctx length parameter.")
+            sys.exit(1)
+
+        self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_source_hf_repo(hf_repo)
+        self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
+        self.gguf_writer.add_context_length(ctx_length)
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count(head_count)
+        self.gguf_writer.add_head_count_kv(head_count_kv)
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+
+        rope_scaling = self.hparams.get("rope_scaling")
+        if rope_scaling is not None and "factor" in rope_scaling:
+            if rope_scaling.get("type") == "linear":
+                self.gguf_writer.add_rope_scale_linear(rope_scaling["factor"])
+
+    def _reverse_hf_permute(self, weights: NDArray, n_head: int, n_kv_head: int | None = None) -> NDArray:
+        if n_kv_head is not None and n_head != n_kv_head:
+            n_head //= n_kv_head
+
+        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+                       .swapaxes(1, 2)
+                       .reshape(weights.shape))
+
+    def _reverse_hf_permute_part(self, weights: NDArray, n_part: int, n_head: int, n_head_kv: int | None = None) -> NDArray:
+        r = weights.shape[0] // 3
+        return self._reverse_hf_permute(weights[r * n_part : r * n_part + r, ...], n_head, n_head_kv)
+
+    def _reverse_hf_part(self, weights: NDArray, n_part: int) -> NDArray:
+        r = weights.shape[0] // 3
+        return weights[r * n_part : r * n_part + r, ...]
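+
+    # Baichuan checkpoints fuse the Q, K and V projections into a single
+    # "W_pack" weight per layer. write_tensors() below splits that tensor
+    # into thirds and, for the Q/K parts, re-orders the rows with the
+    # helpers above so the rotary-embedding layout matches the llama-style
+    # tensors the rest of the converter emits.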
+
+    def write_tensors(self):
+        # Collect tensors from generator object
+        model_kv = dict(self.get_tensors())
+        block_count = self.hparams["num_hidden_layers"]
+        head_count = self.hparams["num_attention_heads"]
+        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        for i in range(block_count):
+            if f"model.layers.{i}.self_attn.W_pack.weight" in model_kv:
+                print(f"Unpacking and permuting layer {i}")
+                model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = self._reverse_hf_permute_part(model_kv[f"model.layers.{i}.self_attn.W_pack.weight"], 0, head_count, head_count)
+                model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = self._reverse_hf_permute_part(model_kv[f"model.layers.{i}.self_attn.W_pack.weight"], 1, head_count, head_count_kv)
+                model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = self._reverse_hf_part(model_kv[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
+                del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"]
+
+        for name, data in model_kv.items():
+            # we don't need these
+            if name.endswith(".rotary_emb.inv_freq"):
+                continue
+
+            old_dtype = data.dtype
+
+            # convert any unsupported data types to float32
+            if data.dtype != torch.float16 and data.dtype != torch.float32:
+                data = data.to(torch.float32)
+
+            data = data.squeeze().numpy()
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Cannot map tensor '{name}'")
+                sys.exit(1)
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: why can't we use these float16 as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+            self.gguf_writer.add_tensor(new_name, data)