Compare commits

2 commits

Author SHA1 Message Date
Georgi Gerganov
5dcccb3a7d
convert : fix tokenizer conversion
ref: https://github.com/ggerganov/llama.cpp/pull/6852
2024-04-23 22:11:09 +03:00
Georgi Gerganov
1732737232
convert : add phi-3 support
2024-04-23 20:38:51 +03:00
2 changed files with 46 additions and 12 deletions

convert-hf-to-gguf.py

@@ -62,6 +62,7 @@ class Model(ABC):
def model_arch(self) -> gguf.MODEL_ARCH:
pass
# TODO: add "default" argument
def find_hparam(self, keys: Sequence[str], optional: bool = False) -> Any:
key = next((k for k in keys if k in self.hparams), None)
if key is not None:
@@ -89,7 +90,12 @@ class Model(ABC):
yield name, data
def set_gguf_parameters(self):
if (mtype := self.find_hparam(["model_type"], optional=True)) is not None:
self.gguf_writer.add_name(mtype)
print(f"gguf: model type = {mtype}")
else:
self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_block_count(self.block_count)
if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
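
As an aside for readers of this hunk: find_hparam simply returns the value of the first key that is present in the config dict, and only returns None when optional=True. A minimal standalone sketch of the same pattern (the hparams dict and sample values here are made up for illustration, not taken from the converter):

```python
from typing import Any, Sequence

def find_hparam(hparams: dict, keys: Sequence[str], optional: bool = False) -> Any:
    # return the value for the first key that actually exists in the config
    key = next((k for k in keys if k in hparams), None)
    if key is not None:
        return hparams[key]
    if optional:
        return None
    raise KeyError(f"could not find any of: {keys}")

hparams = {"model_type": "phi3", "max_position_embeddings": 4096}
print(find_hparam(hparams, ["model_type"], optional=True))         # phi3
print(find_hparam(hparams, ["max_position_embeddings", "n_ctx"]))  # 4096
print(find_hparam(hparams, ["n_ctx"], optional=True))              # None
```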
@@ -332,7 +338,12 @@ class Model(ABC):
tokenizer = SentencePieceProcessor(str(tokenizer_path))
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
for token_id in range(tokenizer.vocab_size()):
piece = tokenizer.id_to_piece(token_id)
text = piece.encode("utf-8")
score = tokenizer.get_score(token_id)
@@ -347,9 +358,9 @@ class Model(ABC):
elif tokenizer.is_byte(token_id):
toktype = SentencePieceTokenTypes.BYTE
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype
added_tokens_file = self.dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
@@ -357,13 +368,14 @@ class Model(ABC):
added_tokens_json = json.load(f)
for key in added_tokens_json:
key = key.encode("utf-8")
if key not in tokens:
tokens.append(key)
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
token_id = added_tokens_json[key]
if (token_id >= vocab_size):
print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
continue
assert len(tokens) == vocab_size
tokens[token_id] = key.encode("utf-8")
scores[token_id] = -1000.0
toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_token_list(tokens)
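
The tokenizer fix above switches from appending entries to pre-allocating vocab_size slots and writing by token id, so ids not covered by the SentencePiece model (or by added_tokens.json) keep a [PAD{i}] placeholder instead of shifting every later token. A self-contained sketch of that behaviour with a stubbed-out tokenizer (all names and values below are illustrative):

```python
# Stub data standing in for SentencePieceProcessor and added_tokens.json.
vocab_size = 8                                          # hparams["vocab_size"]
sp_pieces = {0: "<unk>", 1: "<s>", 2: "</s>", 3: "hi"}  # ids the SP model covers
added_tokens = {"<|end|>": 6, "<|assistant|>": 7}       # token text -> id

# pre-fill every slot so uncovered ids keep a placeholder
tokens = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores = [-10000.0] * vocab_size

for token_id, piece in sp_pieces.items():
    tokens[token_id] = piece.encode("utf-8")
    scores[token_id] = 0.0          # stands in for tokenizer.get_score(token_id)

for key, token_id in added_tokens.items():
    if token_id >= vocab_size:
        print(f"ignore token {token_id}: id is out of range, max={vocab_size - 1}")
        continue
    tokens[token_id] = key.encode("utf-8")
    scores[token_id] = -1000.0

print(tokens)  # ids 4 and 5 remain b'[PAD4]', b'[PAD5]'
```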
@@ -1293,7 +1305,7 @@ class StableLMModel(Model):
self.gguf_writer.add_tensor(new_name, data)
@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "Phi3ForCausalLM")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA
@@ -1322,6 +1334,7 @@ class LlamaModel(Model):
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
@@ -1329,11 +1342,31 @@ class LlamaModel(Model):
def write_tensors(self):
block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
n_embd = self.hparams.get("hidden_size")
n_head = self.hparams.get("num_attention_heads")
n_kv_head = self.hparams.get("num_key_value_heads")
n_experts = self.hparams.get("num_local_experts")
experts = dict()
for name, data_torch in self.get_tensors():
head_dim = n_embd // n_head
tensors = dict(self.get_tensors())
for i in range(block_count):
# Phi-3 transformations
# ref: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/8b29aca7bb785d6336fc19819b045bc7bc584b06/modeling_phi3.py#L379-L384
if (w := tensors.get(f"model.layers.{i}.self_attn.qkv_proj.weight")) is not None:
qpos = n_head * head_dim
tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w[:qpos]
tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[qpos:qpos + n_kv_head * head_dim]
tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[qpos + n_kv_head * head_dim:]
del tensors[f"model.layers.{i}.self_attn.qkv_proj.weight"]
if (w := tensors.get(f"model.layers.{i}.mlp.gate_up_proj.weight")) is not None:
ff_dim = w.shape[0] // 2
tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim]
tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:]
del tensors[f"model.layers.{i}.mlp.gate_up_proj.weight"]
for name, data_torch in tensors.items():
# we don't need these
if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
continue
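
Outside the converter, the Phi-3 transformation above is a plain row-wise slice of the fused weight matrices. A small sketch with random tensors (torch assumed available; the shapes are arbitrary, not Phi-3's real dimensions):

```python
import torch

# Arbitrary small dimensions; real values would come from the model's config.json.
n_embd, n_head, n_kv_head, ff_dim = 64, 8, 8, 96
head_dim = n_embd // n_head

# fused attention projection: rows laid out as [q | k | v]
qkv = torch.randn((n_head + 2 * n_kv_head) * head_dim, n_embd)
qpos = n_head * head_dim
q = qkv[:qpos]
k = qkv[qpos:qpos + n_kv_head * head_dim]
v = qkv[qpos + n_kv_head * head_dim:]
assert q.shape == (n_head * head_dim, n_embd)
assert k.shape == v.shape == (n_kv_head * head_dim, n_embd)

# fused MLP projection: rows laid out as [gate | up]
gate_up = torch.randn(2 * ff_dim, n_embd)
gate, up = gate_up[:ff_dim], gate_up[ff_dim:]
assert gate.shape == up.shape == (ff_dim, n_embd)
```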

llama.cpp

@@ -4352,6 +4352,7 @@ static void llm_load_vocab(
//vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
(t.first == "<|eot_id|>" ||
t.first == "<|im_end|>" ||
t.first == "<|end|>" ||
t.first == "<end_of_turn>"
)
) {
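
The llama.cpp hunk only adds Phi-3's <|end|> to the set of token texts recognized as end-of-turn markers during vocab loading. A rough Python illustration of that kind of lookup (the token table here is invented, not llama.cpp's actual data structures):

```python
# Token texts treated as end-of-turn markers, mirroring the C++ condition above.
EOT_TEXTS = {"<|eot_id|>", "<|im_end|>", "<|end|>", "<end_of_turn>"}

def find_eot_id(id_to_token: dict[int, str]) -> int | None:
    # return the id of the first token whose text is an end-of-turn marker
    return next((tid for tid, text in id_to_token.items() if text in EOT_TEXTS), None)

print(find_eot_id({32000: "<|endoftext|>", 32007: "<|end|>"}))  # 32007
```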