llama : towards llama3 tokenization support (wip)

Georgi Gerganov 2024-04-26 14:55:03 +03:00
parent ed42711b90
commit 4907e41aa7
8 changed files with 298 additions and 121 deletions


@@ -215,50 +215,50 @@ class Model(ABC):
        except KeyError:
            raise NotImplementedError(f'Architecture {arch!r} not supported!') from None

    @staticmethod
    def from_model_architecture(model_architecture):
        if model_architecture == "GPTNeoXForCausalLM":
            return GPTNeoXModel
        if model_architecture == "BloomForCausalLM":
            return BloomModel
        if model_architecture == "MPTForCausalLM":
            return MPTModel
        if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
            return BaichuanModel
        if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
            return FalconModel
        if model_architecture == "GPTBigCodeForCausalLM":
            return StarCoderModel
        if model_architecture == "GPTRefactForCausalLM":
            return RefactModel
        if model_architecture == "PersimmonForCausalLM":
            return PersimmonModel
        if model_architecture == "LlamaForCausalLM":
            return DeepseekCoderModel
        if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
            return StableLMModel
        if model_architecture == "QWenLMHeadModel":
            return QwenModel
        if model_architecture == "Qwen2ForCausalLM":
            return Model
        if model_architecture == "MixtralForCausalLM":
            return MixtralModel
        if model_architecture == "GPT2LMHeadModel":
            return GPT2Model
        if model_architecture == "PhiForCausalLM":
            return Phi2Model
        if model_architecture == "PlamoForCausalLM":
            return PlamoModel
        if model_architecture == "CodeShellForCausalLM":
            return CodeShellModel
        if model_architecture == "OrionForCausalLM":
            return OrionModel
        if model_architecture == "InternLM2ForCausalLM":
            return InternLM2Model
        if model_architecture == "MiniCPMForCausalLM":
            return MiniCPMModel
        if model_architecture == "BertModel":
            return BertModel
    # @staticmethod
    # def from_model_architecture(model_architecture):
    #     if model_architecture == "GPTNeoXForCausalLM":
    #         return GPTNeoXModel
    #     if model_architecture == "BloomForCausalLM":
    #         return BloomModel
    #     if model_architecture == "MPTForCausalLM":
    #         return MPTModel
    #     if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
    #         return BaichuanModel
    #     if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
    #         return FalconModel
    #     if model_architecture == "GPTBigCodeForCausalLM":
    #         return StarCoderModel
    #     if model_architecture == "GPTRefactForCausalLM":
    #         return RefactModel
    #     if model_architecture == "PersimmonForCausalLM":
    #         return PersimmonModel
    #     if model_architecture == "LlamaForCausalLM":
    #         return LlamaModel
    #     if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
    #         return StableLMModel
    #     if model_architecture == "QWenLMHeadModel":
    #         return QwenModel
    #     if model_architecture == "Qwen2ForCausalLM":
    #         return Model
    #     if model_architecture == "MixtralForCausalLM":
    #         return MixtralModel
    #     if model_architecture == "GPT2LMHeadModel":
    #         return GPT2Model
    #     if model_architecture == "PhiForCausalLM":
    #         return Phi2Model
    #     if model_architecture == "PlamoForCausalLM":
    #         return PlamoModel
    #     if model_architecture == "CodeShellForCausalLM":
    #         return CodeShellModel
    #     if model_architecture == "OrionForCausalLM":
    #         return OrionModel
    #     if model_architecture == "InternLM2ForCausalLM":
    #         return InternLM2Model
    #     if model_architecture == "MiniCPMForCausalLM":
    #         return MiniCPMModel
    #     if model_architecture == "BertModel":
    #         return BertModel
    @staticmethod
    def from_model_name(model_name: str):
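
Aside: the if-chain commented out above duplicates the decorator-based registry this script already relies on — `@Model.register("LlamaForCausalLM")` appears further down in this same diff, and the `except KeyError` lookup at the top of the hunk is its failure path. A rough, self-contained sketch of how such a registry works; the exact signatures here are illustrative, not copied from the script:

from __future__ import annotations


class Model:
    # Illustrative sketch of a decorator-based architecture registry; not the
    # script's exact implementation.
    _model_classes: dict[str, type[Model]] = {}

    @classmethod
    def register(cls, *names: str):
        # Class decorator: map each HF architecture name to the decorated subclass.
        def wrap(model_cls: type[Model]) -> type[Model]:
            for name in names:
                cls._model_classes[name] = model_cls
            return model_cls
        return wrap

    @classmethod
    def from_model_architecture(cls, arch: str) -> type[Model]:
        try:
            return cls._model_classes[arch]
        except KeyError:
            raise NotImplementedError(f'Architecture {arch!r} not supported!') from None


@Model.register("LlamaForCausalLM")
class LlamaModel(Model):
    pass


assert Model.from_model_architecture("LlamaForCausalLM") is LlamaModel

Registering via a decorator keeps each architecture's mapping next to its class, which is presumably why the big if-chain is being retired here.
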
@@ -281,10 +281,8 @@ class Model(ABC):
            return RefactModel
        if model_name_lower == "persimmon":
            return PersimmonModel
        if model_name_lower == "deepseekcoder":
            return DeepseekCoderModel
        if model_name_lower == "deepseekllm":
            return DeepseekLLMModel
        if model_name_lower in ("llama", "deepseekcoder", "deepseekllm"):
            return LlamaModel
        return Model

    def _is_model_safetensors(self) -> bool:
@@ -376,7 +374,6 @@ class Model(ABC):
        return tokens, toktypes

    def _set_vocab_gpt2(self, tokenizer_model:str = "gpt2") -> None:
        tokens, toktypes = self.get_basic_vocab()
        self.gguf_writer.add_tokenizer_model(tokenizer_model)
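
For orientation, `_set_vocab_gpt2` above writes a GPT-2 style (byte-level BPE) vocabulary into the GGUF file. A minimal sketch of how such a vocab can be pulled out of an HF tokenizer with `transformers` — the model id is hypothetical, and the real method also records token types, merges, and special tokens:

from transformers import AutoTokenizer  # assumes transformers is installed

# Hypothetical model id, purely for illustration.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-6.7b-base")

# Invert the token -> id map and pad any unused ids so the list is dense.
reverse_vocab = {tok_id: tok for tok, tok_id in tokenizer.get_vocab().items()}
tokens = [reverse_vocab.get(i, f"[PAD{i}]") for i in range(max(reverse_vocab) + 1)]
print(len(tokens), tokens[:5])
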
@@ -1312,31 +1309,7 @@ class PersimmonModel(Model):
            n_dims = len(data.shape)
            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
            self.gguf_writer.add_tensor(new_name, data)


@Model.register("LlamaForCausalLM")
class DeepseekCoderModel(Model):
    model_arch = gguf.MODEL_ARCH.LLAMA

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        head_count = self.hparams["num_attention_heads"]
        head_count_kv = self.hparams.get("num_key_value_heads", head_count)
        self.gguf_writer.add_head_count(head_count)
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(head_count_kv)
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
            if self.hparams["rope_scaling"].get("type") == "linear":
                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

    def set_vocab(self):
        self._set_vocab_gpt2("deepseek_coder")


class DeepseekLLMModel(DeepseekCoderModel):
    def set_vocab(self):
        self._set_vocab_gpt2("deepseek_llm")


@Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
class StableLMModel(Model):
@@ -1479,6 +1452,11 @@ class LlamaModel(Model):
        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
        self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])

        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
            if self.hparams["rope_scaling"].get("type") == "linear":
                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
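
The guard just above only emits rope-scaling metadata when the HF config carries a linear `rope_scaling` entry with a `factor`. A tiny worked example of the shape it expects (values are made up):

# Illustrative fragment of an HF config.json as it would land in self.hparams.
hparams = {
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "rope_scaling": {"type": "linear", "factor": 4.0},
}

rope_scaling = hparams.get("rope_scaling")
if rope_scaling is not None and "factor" in rope_scaling:
    if rope_scaling.get("type") == "linear":
        # Mirrors the two gguf_writer calls above: scaling type LINEAR, factor 4.0.
        print("rope scaling: linear, factor =", rope_scaling["factor"])
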

    # Same as super class, but permuting q_proj, k_proj
    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
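
The "permuting q_proj, k_proj" comment refers to reordering the attention projection rows between the HF checkpoint layout and the RoPE layout llama.cpp expects. A rough numpy sketch of that kind of per-head row permutation — the function name, shapes, and direction of the conversion are my reading, not copied from the script:

import numpy as np

def permute_qk(weights: np.ndarray, n_head: int) -> np.ndarray:
    # Per head, rows stored as [first halves | second halves] of each RoPE pair
    # are interleaved back into (pair 0, pair 1, ...) order.
    rows, cols = weights.shape
    head_dim = rows // n_head
    return (weights.reshape(n_head, 2, head_dim // 2, cols)
                   .swapaxes(1, 2)
                   .reshape(rows, cols))

# Tiny check: one head, head_dim = 4 -> row order becomes 0, 2, 1, 3.
w = np.arange(4 * 3).reshape(4, 3)
print(permute_qk(w, n_head=1)[:, 0])  # [0 6 3 9]
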