fix
This commit is contained in:
parent
9cecb7a613
commit
4494a9f655
1 changed files with 2 additions and 4 deletions
|
@ -175,7 +175,7 @@ class Model:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model_name(model_name: str):
|
def from_model_name(model_name: str):
|
||||||
model_name_lower = model_name.lower()
|
model_name_lower = model_name.lower()
|
||||||
if model_name_lower == "stablelmepoch":
|
if model_name_lower in ("stablelmepoch", "llavastablelmepoch"):
|
||||||
return StableLMModel
|
return StableLMModel
|
||||||
if model_name_lower == "gptneox":
|
if model_name_lower == "gptneox":
|
||||||
return GPTNeoXModel
|
return GPTNeoXModel
|
||||||
|
@ -183,7 +183,7 @@ class Model:
|
||||||
return BloomModel
|
return BloomModel
|
||||||
if model_name_lower == "mpt":
|
if model_name_lower == "mpt":
|
||||||
return MPTModel
|
return MPTModel
|
||||||
if model_name_lower in ("baichuan", "baichuan"):
|
if model_name_lower in ("baichuan"):
|
||||||
return BaichuanModel
|
return BaichuanModel
|
||||||
if model_name_lower in ("falcon", "rw"):
|
if model_name_lower in ("falcon", "rw"):
|
||||||
return FalconModel
|
return FalconModel
|
||||||
|
@ -195,8 +195,6 @@ class Model:
|
||||||
return PersimmonModel
|
return PersimmonModel
|
||||||
if model_name_lower == "deepseekcoder":
|
if model_name_lower == "deepseekcoder":
|
||||||
return DeepseekCoderModel
|
return DeepseekCoderModel
|
||||||
if model_name_lower == "stablelm":
|
|
||||||
return StableLMModel
|
|
||||||
return Model
|
return Model
|
||||||
|
|
||||||
def _is_model_safetensors(self) -> bool:
|
def _is_model_safetensors(self) -> bool:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue