Add [UNTESTED] Baichuan support

Galunid 2023-10-29 01:38:35 +02:00
parent 0ff237105d
commit 8618b4e74c

model.py (183 additions)
@@ -8,7 +8,9 @@ import contextlib
import numpy as np
from pathlib import Path
from typing import TypeAlias, Any
NDArray: TypeAlias = 'np.ndarray[Any, Any]'
class Model:
def __init__(self, dir_model: Path, ftype: int, fname_out: Path):
@@ -43,6 +45,8 @@ class Model:
return gguf.MODEL_ARCH.BLOOM
if arch == "MPTForCausalLM":
return gguf.MODEL_ARCH.MPT
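        # Baichuan-7B configs spell the architecture "BaiChuanForCausalLM",
        # while Baichuan-13B and Baichuan2 use "BaichuanForCausalLM".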
if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
return gguf.MODEL_ARCH.BAICHUAN
raise NotImplementedError(f'Architecture "{arch}" not supported!')
def set_vocab(self):
@@ -174,6 +178,8 @@ class Model:
return BloomModel
if model_architecture == "MPTForCausalLM":
return MPTModel
if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
return BaichuanModel
return Model
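# A hypothetical usage sketch (the enclosing helper's name and its caller are
# assumptions, not part of this diff): a driver resolves the model class from
# config.json's "architectures" entry and instantiates it, e.g.
#
#     model_class = Model.from_model_architecture("BaichuanForCausalLM")
#     model = model_class(dir_model, ftype, fname_out)  # -> BaichuanModel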
class StableLMModel(Model):
@@ -288,3 +294,180 @@ class MPTModel(Model):
if self.hparams["attn_config"]["clip_qkv"] is not None:
self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"])
self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"])
class BaichuanModel(Model):
def set_vocab(self):
from sentencepiece import SentencePieceProcessor # type: ignore[import]
tokens: list[bytes] = []
scores: list[float] = []
toktypes: list[int] = []
tokenizer_model_file = self.dir_model / 'tokenizer.model'
if not tokenizer_model_file.is_file():
print(f'Error: Missing {tokenizer_model_file}', file = sys.stderr)
sys.exit(1)
# vocab type sentencepiece
print("gguf: get sentencepiece tokenizer vocab, scores and token types")
tokenizer = SentencePieceProcessor(str(tokenizer_model_file))
vocab_size = self.hparams.get('vocab_size')
if vocab_size is None:
vocab_size = tokenizer.vocab_size()
for i in range(vocab_size):
text: bytes
score: float
piece = tokenizer.id_to_piece(i)
text = piece.encode("utf-8")
score = tokenizer.get_score(i)
            toktype = 1  # default to normal token type
if tokenizer.is_unknown(i):
toktype = 2
if tokenizer.is_control(i):
toktype = 3
            # toktype = 4 is user-defined (tokens from added_tokens.json)
if tokenizer.is_unused(i):
toktype = 5
if tokenizer.is_byte(i):
toktype = 6
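            # NOTE: these integer codes follow gguf.TokenType (NORMAL = 1,
            # UNKNOWN = 2, CONTROL = 3, USER_DEFINED = 4, UNUSED = 5,
            # BYTE = 6); using the enum members directly would be
            # self-documenting.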
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
added_tokens_file = self.dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
with open(added_tokens_file, "r", encoding="utf-8") as f:
addtokens_json = json.load(f)
print("gguf: get added tokens")
for key in addtokens_json:
                    tokens.append(key.encode("utf-8"))
scores.append(-1000.0)
toktypes.append(4) # user-defined token type
self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab = len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
head_count = self.hparams["num_attention_heads"]
if "num_key_value_heads" in self.hparams:
head_count_kv = self.hparams["num_key_value_heads"]
else:
head_count_kv = head_count
if "_name_or_path" in self.hparams:
hf_repo = self.hparams["_name_or_path"]
else:
hf_repo = ""
ctx_length = 0
if "max_sequence_length" in self.hparams:
ctx_length = self.hparams["max_sequence_length"]
elif "max_position_embeddings" in self.hparams:
ctx_length = self.hparams["max_position_embeddings"]
elif "model_max_length" in self.hparams:
ctx_length = self.hparams["model_max_length"]
else:
print("gguf: can not find ctx length parameter.")
sys.exit()
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_source_hf_repo(hf_repo)
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
self.gguf_writer.add_context_length(ctx_length)
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count(head_count)
self.gguf_writer.add_head_count_kv(head_count_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
if "rope_scaling" in self.hparams and self.hparams["rope_scaling"] != None and "factor" in self.hparams["rope_scaling"]:
if "type" in self.hparams["rope_scaling"]:
if self.hparams["rope_scaling"]["type"] == "linear":
self.gguf_writer.add_rope_scale_linear(self.hparams["rope_scaling"]["factor"])
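        # For example, a checkpoint fine-tuned for 2x context would typically
        # ship "rope_scaling": {"type": "linear", "factor": 2.0} in its
        # config.json; only linear scaling is recorded in the GGUF header,
        # any other type is deliberately skipped here.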
def _reverse_hf_permute(self, weights: NDArray, n_head: int, n_kv_head: int | None = None) -> NDArray:
if n_kv_head is not None and n_head != n_kv_head:
n_head //= n_kv_head
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape))
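        # Worked example (n_head = 1, a length-4 weight with rows tagged
        # 0..3): reshape(1, 2, 2) groups the rows into the two rotary halves
        # [[0, 1], [2, 3]]; swapaxes(1, 2) re-interleaves them to
        # [[0, 2], [1, 3]], so row order (0, 1, 2, 3) becomes (0, 2, 1, 3).
        # This undoes the Q/K interleave applied by transformers'
        # convert_llama_weights_to_hf script, restoring the row order that
        # llama.cpp's RoPE implementation expects.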
    def _reverse_hf_permute_part(self, weights: NDArray, n_part: int, n_head: int, n_head_kv: int | None = None) -> NDArray:
        r = weights.shape[0] // 3
        return self._reverse_hf_permute(weights[r * n_part : r * n_part + r, ...], n_head, n_head_kv)
def _reverse_hf_part(self, weights: NDArray, n_part: int) -> NDArray:
r = weights.shape[0] // 3
return weights[r * n_part : r * n_part + r, ...]
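    # Both helpers assume Baichuan's fused attention weight W_pack, which
    # stacks Q, K and V along dim 0 into a (3 * hidden_size, hidden_size)
    # tensor: n_part 0/1/2 selects the corresponding third. Q and K also
    # need the rotary un-permute above; V is taken verbatim.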
def write_tensors(self):
# Collect tensors from generator object
model_kv = dict(self.get_tensors())
block_count = self.hparams["num_hidden_layers"]
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
for i in range(block_count):
if f"model.layers.{i}.self_attn.W_pack.weight" in model_kv:
print(f"Unpacking and permuting layer {i}")
model_kv[f"model.layers.{i}.self_attn.q_proj.weight"] = self._reverse_hf_permute_part(model_kv[f"model.layers.{i}.self_attn.W_pack.weight"],0,head_count,head_count)
model_kv[f"model.layers.{i}.self_attn.k_proj.weight"] = self._reverse_hf_permute_part(model_kv[f"model.layers.{i}.self_attn.W_pack.weight"],1,head_count,head_count_kv)
model_kv[f"model.layers.{i}.self_attn.v_proj.weight"] = self._reverse_hf_part(model_kv[f"model.layers.{i}.self_attn.W_pack.weight"],2)
del model_kv[f"model.layers.{i}.self_attn.W_pack.weight"]
for name, data in model_kv.items():
# we don't need these
if name.endswith(".rotary_emb.inv_freq"):
continue
old_dtype = data.dtype
# convert any unsupported data types to float32
if data.dtype != torch.float16 and data.dtype != torch.float32:
data = data.to(torch.float32)
data = data.squeeze().numpy()
# map tensor names
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None:
print("Can not map tensor '" + name + "'")
sys.exit()
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
            # TODO: why can't we use these float16 tensors as-is? There should be no reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)
print(name + " -> " + new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
self.gguf_writer.add_tensor(new_name, data)
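# A hypothetical end-to-end sketch (driver wiring and model path are
# assumptions, not part of this diff) of how BaichuanModel would be
# exercised, mirroring the Model API used above:
#
#     model = BaichuanModel(Path("Baichuan2-7B-Chat"), ftype = 1,
#                           fname_out = Path("baichuan-f16.gguf"))
#     model.set_gguf_parameters()
#     model.set_vocab()
#     model.write_tensors()
#     # ...followed by whatever header/tensor flush the base Model performs.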