This commit is contained in:
Bingxuan Wang 2023-11-16 17:02:22 +08:00
parent 9cecb7a613
commit 4494a9f655

View file

@ -175,7 +175,7 @@ class Model:
@staticmethod
def from_model_name(model_name: str):
model_name_lower = model_name.lower()
if model_name_lower == "stablelmepoch":
if model_name_lower in ("stablelmepoch", "llavastablelmepoch"):
return StableLMModel
if model_name_lower == "gptneox":
return GPTNeoXModel
@ -183,7 +183,7 @@ class Model:
return BloomModel
if model_name_lower == "mpt":
return MPTModel
if model_name_lower in ("baichuan", "baichuan"):
if model_name_lower in ("baichuan"):
return BaichuanModel
if model_name_lower in ("falcon", "rw"):
return FalconModel
@ -195,8 +195,6 @@ class Model:
return PersimmonModel
if model_name_lower == "deepseekcoder":
return DeepseekCoderModel
if model_name_lower == "stablelm":
return StableLMModel
return Model
def _is_model_safetensors(self) -> bool: