Add JAIS model(s) (#8118)

* Add `JAIS` model(s)

* cleanup

* address review comments

* remove hack

* un-hardcode max-alibi-bias

* minor tweaks

---------

Co-authored-by: fmz <quic_fzaghlou@quic.com>
This commit is contained in:
Faisal Zaghloul 2024-07-02 10:36:00 -04:00 committed by GitHub
parent 023b8807e1
commit 968967376d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 288 additions and 9 deletions

View file

@ -164,6 +164,7 @@ class MODEL_ARCH(IntEnum):
DEEPSEEK2 = auto()
BITNET = auto()
T5 = auto()
JAIS = auto()
class MODEL_TENSOR(IntEnum):
@ -288,6 +289,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.JAIS: "jais",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@ -954,6 +956,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ENC_FFN_UP,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.JAIS: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_UP,
],
# TODO
}