diff --git a/model.py b/model.py index 2ebd6881e..c8ae1d51a 100644 --- a/model.py +++ b/model.py @@ -153,7 +153,7 @@ class Model: return MPTModel if model_architecture in ("BaichuanForCausalLM", "BaiChuanForCausalLM"): return BaichuanModel - if model_architecture == "FalconForCausalLM": + if model_architecture in ("FalconForCausalLM", "RWForCausalLM"): return FalconModel if model_architecture == "GPTBigCodeForCausalLM": return StarCoderModel