convert : add phi-3 support
This commit is contained in:
parent
4e96a812b3
commit
1732737232
2 changed files with 38 additions and 3 deletions
|
@ -62,6 +62,7 @@ class Model(ABC):
|
|||
def model_arch(self) -> gguf.MODEL_ARCH:
|
||||
pass
|
||||
|
||||
# TODO: add "default" argument
|
||||
def find_hparam(self, keys: Sequence[str], optional: bool = False) -> Any:
|
||||
key = next((k for k in keys if k in self.hparams), None)
|
||||
if key is not None:
|
||||
|
@ -89,7 +90,12 @@ class Model(ABC):
|
|||
yield name, data
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_name(self.dir_model.name)
|
||||
if (mtype := self.find_hparam(["model_type"], optional=True)) is not None:
|
||||
self.gguf_writer.add_name(mtype)
|
||||
print(f"gguf: model type = {mtype}")
|
||||
else:
|
||||
self.gguf_writer.add_name(self.dir_model.name)
|
||||
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
|
||||
if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
|
||||
|
@ -363,6 +369,13 @@ class Model(ABC):
|
|||
scores.append(-1000.0)
|
||||
toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
|
||||
|
||||
# pad remaining tokens
|
||||
for i in range(vocab_size - len(tokens)):
|
||||
print(f"gguf: padding token {i}")
|
||||
tokens.append(f"[PAD{i}]")
|
||||
scores.append(-1000.0)
|
||||
toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
|
||||
|
||||
assert len(tokens) == vocab_size
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("llama")
|
||||
|
@ -1293,7 +1306,7 @@ class StableLMModel(Model):
|
|||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
|
||||
@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "Phi3ForCausalLM")
|
||||
class LlamaModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
|
||||
|
@ -1322,6 +1335,7 @@ class LlamaModel(Model):
|
|||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
|
||||
|
||||
|
@ -1329,11 +1343,31 @@ class LlamaModel(Model):
|
|||
def write_tensors(self):
|
||||
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
|
||||
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
|
||||
n_embd = self.hparams.get("hidden_size")
|
||||
n_head = self.hparams.get("num_attention_heads")
|
||||
n_kv_head = self.hparams.get("num_key_value_heads")
|
||||
n_experts = self.hparams.get("num_local_experts")
|
||||
experts = dict()
|
||||
for name, data_torch in self.get_tensors():
|
||||
|
||||
head_dim = n_embd // n_head
|
||||
|
||||
tensors = dict(self.get_tensors())
|
||||
for i in range(block_count):
|
||||
# Phi-3 transformations
|
||||
# ref: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/8b29aca7bb785d6336fc19819b045bc7bc584b06/modeling_phi3.py#L379-L384
|
||||
if (w := tensors.get(f"model.layers.{i}.self_attn.qkv_proj.weight")) is not None:
|
||||
qpos = n_head * head_dim
|
||||
tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w[:qpos]
|
||||
tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[qpos:qpos + n_kv_head * head_dim]
|
||||
tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[qpos + n_kv_head * head_dim:]
|
||||
del tensors[f"model.layers.{i}.self_attn.qkv_proj.weight"]
|
||||
if (w := tensors.get(f"model.layers.{i}.mlp.gate_up_proj.weight")) is not None:
|
||||
ff_dim = w.shape[0] // 2
|
||||
tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim]
|
||||
tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:]
|
||||
del tensors[f"model.layers.{i}.mlp.gate_up_proj.weight"]
|
||||
|
||||
for name, data_torch in tensors.items():
|
||||
# we don't need these
|
||||
if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
|
||||
continue
|
||||
|
|
|
@ -4352,6 +4352,7 @@ static void llm_load_vocab(
|
|||
//vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
|
||||
(t.first == "<|eot_id|>" ||
|
||||
t.first == "<|im_end|>" ||
|
||||
t.first == "<|end|>" ||
|
||||
t.first == "<end_of_turn>"
|
||||
)
|
||||
) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue