This commit is contained in:
Bingxuan Wang 2023-11-16 17:02:22 +08:00
parent 9cecb7a613
commit 4494a9f655

View file

@ -175,7 +175,7 @@ class Model:
@staticmethod @staticmethod
def from_model_name(model_name: str): def from_model_name(model_name: str):
model_name_lower = model_name.lower() model_name_lower = model_name.lower()
if model_name_lower == "stablelmepoch": if model_name_lower in ("stablelmepoch", "llavastablelmepoch"):
return StableLMModel return StableLMModel
if model_name_lower == "gptneox": if model_name_lower == "gptneox":
return GPTNeoXModel return GPTNeoXModel
@ -183,7 +183,7 @@ class Model:
return BloomModel return BloomModel
if model_name_lower == "mpt": if model_name_lower == "mpt":
return MPTModel return MPTModel
if model_name_lower in ("baichuan", "baichuan"): if model_name_lower in ("baichuan"):
return BaichuanModel return BaichuanModel
if model_name_lower in ("falcon", "rw"): if model_name_lower in ("falcon", "rw"):
return FalconModel return FalconModel
@ -195,8 +195,6 @@ class Model:
return PersimmonModel return PersimmonModel
if model_name_lower == "deepseekcoder": if model_name_lower == "deepseekcoder":
return DeepseekCoderModel return DeepseekCoderModel
if model_name_lower == "stablelm":
return StableLMModel
return Model return Model
def _is_model_safetensors(self) -> bool: def _is_model_safetensors(self) -> bool: