merged the changes from deepseeker models to main branch
This commit is contained in:
parent
83b72cb086
commit
6fbab2dbc8
15 changed files with 886 additions and 151 deletions
|
@ -215,6 +215,78 @@ class Model(ABC):
|
|||
except KeyError:
|
||||
raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
|
||||
|
||||
@staticmethod
|
||||
def from_model_architecture(model_architecture):
|
||||
if model_architecture == "GPTNeoXForCausalLM":
|
||||
return GPTNeoXModel
|
||||
if model_architecture == "BloomForCausalLM":
|
||||
return BloomModel
|
||||
if model_architecture == "MPTForCausalLM":
|
||||
return MPTModel
|
||||
if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
|
||||
return BaichuanModel
|
||||
if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
|
||||
return FalconModel
|
||||
if model_architecture == "GPTBigCodeForCausalLM":
|
||||
return StarCoderModel
|
||||
if model_architecture == "GPTRefactForCausalLM":
|
||||
return RefactModel
|
||||
if model_architecture == "PersimmonForCausalLM":
|
||||
return PersimmonModel
|
||||
if model_architecture == "LlamaForCausalLM":
|
||||
return DeepseekCoderModel
|
||||
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
|
||||
return StableLMModel
|
||||
if model_architecture == "QWenLMHeadModel":
|
||||
return QwenModel
|
||||
if model_architecture == "Qwen2ForCausalLM":
|
||||
return Model
|
||||
if model_architecture == "MixtralForCausalLM":
|
||||
return MixtralModel
|
||||
if model_architecture == "GPT2LMHeadModel":
|
||||
return GPT2Model
|
||||
if model_architecture == "PhiForCausalLM":
|
||||
return Phi2Model
|
||||
if model_architecture == "PlamoForCausalLM":
|
||||
return PlamoModel
|
||||
if model_architecture == "CodeShellForCausalLM":
|
||||
return CodeShellModel
|
||||
if model_architecture == "OrionForCausalLM":
|
||||
return OrionModel
|
||||
if model_architecture == "InternLM2ForCausalLM":
|
||||
return InternLM2Model
|
||||
if model_architecture == "MiniCPMForCausalLM":
|
||||
return MiniCPMModel
|
||||
if model_architecture == "BertModel":
|
||||
return BertModel
|
||||
|
||||
@staticmethod
|
||||
def from_model_name(model_name: str):
|
||||
model_name_lower = model_name.lower()
|
||||
if model_name_lower in ("stablelmepoch", "llavastablelmepoch"):
|
||||
return StableLMModel
|
||||
if model_name_lower == "gptneox":
|
||||
return GPTNeoXModel
|
||||
if model_name_lower == "bloom":
|
||||
return BloomModel
|
||||
if model_name_lower == "mpt":
|
||||
return MPTModel
|
||||
if model_name_lower in ("baichuan"):
|
||||
return BaichuanModel
|
||||
if model_name_lower in ("falcon", "rw"):
|
||||
return FalconModel
|
||||
if model_name_lower == "gptbigcode":
|
||||
return StarCoderModel
|
||||
if model_name_lower == "gptrefact":
|
||||
return RefactModel
|
||||
if model_name_lower == "persimmon":
|
||||
return PersimmonModel
|
||||
if model_name_lower == "deepseekcoder":
|
||||
return DeepseekCoderModel
|
||||
if model_name_lower == "deepseekllm":
|
||||
return DeepseekLLMModel
|
||||
return Model
|
||||
|
||||
def _is_model_safetensors(self) -> bool:
|
||||
return Model.count_model_parts(self.dir_model, ".safetensors") > 0
|
||||
|
||||
|
@ -228,6 +300,53 @@ class Model(ABC):
|
|||
return ("pytorch_model.bin",)
|
||||
return (f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin" for n in range(1, self.num_parts + 1))
|
||||
|
||||
def _get_model_architecture(self) -> gguf.MODEL_ARCH:
|
||||
arch = self.hparams["architectures"][0]
|
||||
if arch == "GPTNeoXForCausalLM":
|
||||
return gguf.MODEL_ARCH.GPTNEOX
|
||||
if arch == "BloomForCausalLM":
|
||||
return gguf.MODEL_ARCH.BLOOM
|
||||
if arch == "MPTForCausalLM":
|
||||
return gguf.MODEL_ARCH.MPT
|
||||
if arch in ("BaichuanForCausalLM", "BaiChuanForCausalLM"):
|
||||
return gguf.MODEL_ARCH.BAICHUAN
|
||||
if arch in ("FalconForCausalLM", "RWForCausalLM"):
|
||||
return gguf.MODEL_ARCH.FALCON
|
||||
if arch == "GPTBigCodeForCausalLM":
|
||||
return gguf.MODEL_ARCH.STARCODER
|
||||
if arch == "GPTRefactForCausalLM":
|
||||
return gguf.MODEL_ARCH.REFACT
|
||||
if arch == "PersimmonForCausalLM":
|
||||
return gguf.MODEL_ARCH.PERSIMMON
|
||||
if arch == "LlamaForCausalLM":
|
||||
return gguf.MODEL_ARCH.LLAMA
|
||||
if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
|
||||
return gguf.MODEL_ARCH.STABLELM
|
||||
if arch == "QWenLMHeadModel":
|
||||
return gguf.MODEL_ARCH.QWEN
|
||||
if arch == "Qwen2ForCausalLM":
|
||||
return gguf.MODEL_ARCH.QWEN2
|
||||
if arch == "MixtralForCausalLM":
|
||||
return gguf.MODEL_ARCH.LLAMA
|
||||
if arch == "GPT2LMHeadModel":
|
||||
return gguf.MODEL_ARCH.GPT2
|
||||
if arch == "PhiForCausalLM":
|
||||
return gguf.MODEL_ARCH.PHI2
|
||||
if arch == "PlamoForCausalLM":
|
||||
return gguf.MODEL_ARCH.PLAMO
|
||||
if arch == "CodeShellForCausalLM":
|
||||
return gguf.MODEL_ARCH.CODESHELL
|
||||
if arch == "OrionForCausalLM":
|
||||
return gguf.MODEL_ARCH.ORION
|
||||
if arch == "InternLM2ForCausalLM":
|
||||
return gguf.MODEL_ARCH.INTERNLM2
|
||||
if arch == "MiniCPMForCausalLM":
|
||||
return gguf.MODEL_ARCH.MINICPM
|
||||
if arch == "BertModel":
|
||||
return gguf.MODEL_ARCH.BERT
|
||||
|
||||
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
||||
|
||||
# used for GPT-2 BPE and WordPiece vocabs
|
||||
def get_basic_vocab(self) -> tuple[list[str], list[int]]:
|
||||
tokens: list[str] = []
|
||||
|
@ -257,9 +376,10 @@ class Model(ABC):
|
|||
|
||||
return tokens, toktypes
|
||||
|
||||
def _set_vocab_gpt2(self) -> None:
|
||||
|
||||
def _set_vocab_gpt2(self, tokenizer_model:str = "gpt2") -> None:
|
||||
tokens, toktypes = self.get_basic_vocab()
|
||||
self.gguf_writer.add_tokenizer_model("gpt2")
|
||||
self.gguf_writer.add_tokenizer_model(tokenizer_model)
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_types(toktypes)
|
||||
|
||||
|
@ -1192,7 +1312,29 @@ class PersimmonModel(Model):
|
|||
n_dims = len(data.shape)
|
||||
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
@Model.register("LlamaForCausalLM")
|
||||
class DeepseekCoderModel(Model):
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
|
||||
|
||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||
if self.hparams["rope_scaling"].get("type") == "linear":
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_gpt2("deepseek_coder")
|
||||
|
||||
class DeepseekLLMModel(DeepseekCoderModel):
|
||||
def set_vocab(self):
|
||||
self._set_vocab_gpt2("deepseek_llm")
|
||||
|
||||
@Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
|
||||
class StableLMModel(Model):
|
||||
|
@ -2843,6 +2985,7 @@ def parse_args() -> argparse.Namespace:
|
|||
help="directory containing model file",
|
||||
)
|
||||
parser.add_argument("--use-temp-file", action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)")
|
||||
parser.add_argument("--model-name", type=str, default=None, help="name of the model")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue