fix
This commit is contained in:
parent
9cecb7a613
commit
4494a9f655
1 changed files with 2 additions and 4 deletions
|
@ -175,7 +175,7 @@ class Model:
|
|||
@staticmethod
|
||||
def from_model_name(model_name: str):
|
||||
model_name_lower = model_name.lower()
|
||||
if model_name_lower == "stablelmepoch":
|
||||
if model_name_lower in ("stablelmepoch", "llavastablelmepoch"):
|
||||
return StableLMModel
|
||||
if model_name_lower == "gptneox":
|
||||
return GPTNeoXModel
|
||||
|
@ -183,7 +183,7 @@ class Model:
|
|||
return BloomModel
|
||||
if model_name_lower == "mpt":
|
||||
return MPTModel
|
||||
if model_name_lower in ("baichuan", "baichuan"):
|
||||
if model_name_lower in ("baichuan"):
|
||||
return BaichuanModel
|
||||
if model_name_lower in ("falcon", "rw"):
|
||||
return FalconModel
|
||||
|
@ -195,8 +195,6 @@ class Model:
|
|||
return PersimmonModel
|
||||
if model_name_lower == "deepseekcoder":
|
||||
return DeepseekCoderModel
|
||||
if model_name_lower == "stablelm":
|
||||
return StableLMModel
|
||||
return Model
|
||||
|
||||
def _is_model_safetensors(self) -> bool:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue