llm : add bloom models (#3553)
* feat: Support bloom models * fix(bloom): fix model size --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
0aa6595ae0
commit
02d2875def
3 changed files with 678 additions and 55 deletions
|
@ -88,29 +88,31 @@ class MODEL_ARCH(IntEnum):
|
|||
PERSIMMON : int = auto()
|
||||
REFACT : int = auto()
|
||||
BERT : int = auto()
|
||||
BLOOM : int = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
TOKEN_EMBD : int = auto()
|
||||
TOKEN_TYPES : int = auto()
|
||||
POS_EMBD : int = auto()
|
||||
OUTPUT : int = auto()
|
||||
OUTPUT_NORM : int = auto()
|
||||
ROPE_FREQS : int = auto()
|
||||
ATTN_Q : int = auto()
|
||||
ATTN_K : int = auto()
|
||||
ATTN_V : int = auto()
|
||||
ATTN_QKV : int = auto()
|
||||
ATTN_OUT : int = auto()
|
||||
ATTN_NORM : int = auto()
|
||||
ATTN_NORM_2 : int = auto()
|
||||
ATTN_ROT_EMBD: int = auto()
|
||||
FFN_GATE : int = auto()
|
||||
FFN_DOWN : int = auto()
|
||||
FFN_UP : int = auto()
|
||||
FFN_NORM : int = auto()
|
||||
ATTN_Q_NORM : int = auto()
|
||||
ATTN_K_NORM : int = auto()
|
||||
TOKEN_EMBD : int = auto()
|
||||
TOKEN_EMBD_NORM : int = auto()
|
||||
TOKEN_TYPES : int = auto()
|
||||
POS_EMBD : int = auto()
|
||||
OUTPUT : int = auto()
|
||||
OUTPUT_NORM : int = auto()
|
||||
ROPE_FREQS : int = auto()
|
||||
ATTN_Q : int = auto()
|
||||
ATTN_K : int = auto()
|
||||
ATTN_V : int = auto()
|
||||
ATTN_QKV : int = auto()
|
||||
ATTN_OUT : int = auto()
|
||||
ATTN_NORM : int = auto()
|
||||
ATTN_NORM_2 : int = auto()
|
||||
ATTN_ROT_EMBD : int = auto()
|
||||
FFN_GATE : int = auto()
|
||||
FFN_DOWN : int = auto()
|
||||
FFN_UP : int = auto()
|
||||
FFN_NORM : int = auto()
|
||||
ATTN_Q_NORM : int = auto()
|
||||
ATTN_K_NORM : int = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
|
@ -125,29 +127,31 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.PERSIMMON: "persimmon",
|
||||
MODEL_ARCH.REFACT: "refact",
|
||||
MODEL_ARCH.BERT: "bert",
|
||||
MODEL_ARCH.BLOOM: "bloom",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
|
||||
MODEL_TENSOR.TOKEN_TYPES: "token_types",
|
||||
MODEL_TENSOR.POS_EMBD: "position_embd",
|
||||
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
|
||||
MODEL_TENSOR.OUTPUT: "output",
|
||||
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
|
||||
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
|
||||
MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
|
||||
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
|
||||
MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
|
||||
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
|
||||
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
|
||||
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
|
||||
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
|
||||
MODEL_TENSOR.TOKEN_TYPES: "token_types",
|
||||
MODEL_TENSOR.POS_EMBD: "position_embd",
|
||||
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
|
||||
MODEL_TENSOR.OUTPUT: "output",
|
||||
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
|
||||
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
|
||||
MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
|
||||
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
|
||||
MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
|
||||
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
|
||||
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
|
||||
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
|
||||
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
|
||||
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
|
||||
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
|
@ -282,6 +286,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.BLOOM: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.GPT2: [
|
||||
# TODO
|
||||
],
|
||||
|
@ -311,6 +327,7 @@ class TensorNameMap:
|
|||
"gpt_neox.embed_in", # gptneox
|
||||
"transformer.wte", # gpt2 gpt-j mpt refact
|
||||
"transformer.word_embeddings", # falcon
|
||||
"word_embeddings", # bloom
|
||||
"model.embed_tokens", # llama-hf
|
||||
"tok_embeddings", # llama-pth
|
||||
"embeddings.word_embeddings", # bert
|
||||
|
@ -322,6 +339,11 @@ class TensorNameMap:
|
|||
"embeddings.token_type_embeddings", # bert
|
||||
),
|
||||
|
||||
# Normalization of token embeddings
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM: (
|
||||
"word_embeddings_layernorm", # bloom
|
||||
),
|
||||
|
||||
# Position embeddings
|
||||
MODEL_TENSOR.POS_EMBD: (
|
||||
"transformer.wpe", # gpt2
|
||||
|
@ -332,7 +354,7 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.OUTPUT: (
|
||||
"embed_out", # gptneox
|
||||
"lm_head", # gpt2 mpt falcon llama-hf baichuan
|
||||
"output", # llama-pth
|
||||
"output", # llama-pth bloom
|
||||
"word_embeddings_for_head", # persimmon
|
||||
),
|
||||
|
||||
|
@ -344,7 +366,7 @@ class TensorNameMap:
|
|||
"norm", # llama-pth
|
||||
"embeddings.LayerNorm", # bert
|
||||
"transformer.norm_f", # mpt
|
||||
"ln_f", # refact
|
||||
"ln_f", # refact bloom
|
||||
"language_model.encoder.final_layernorm", # persimmon
|
||||
),
|
||||
|
||||
|
@ -361,6 +383,7 @@ class TensorNameMap:
|
|||
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact
|
||||
"transformer.blocks.{bid}.norm_1", # mpt
|
||||
"transformer.h.{bid}.input_layernorm", # falcon7b
|
||||
"h.{bid}.input_layernorm", # bloom
|
||||
"transformer.h.{bid}.ln_mlp", # falcon40b
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf
|
||||
"layers.{bid}.attention_norm", # llama-pth
|
||||
|
@ -379,6 +402,7 @@ class TensorNameMap:
|
|||
"transformer.h.{bid}.attn.c_attn", # gpt2
|
||||
"transformer.blocks.{bid}.attn.Wqkv", # mpt
|
||||
"transformer.h.{bid}.self_attention.query_key_value", # falcon
|
||||
"h.{bid}.self_attention.query_key_value", # bloom
|
||||
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
|
||||
),
|
||||
|
||||
|
@ -412,6 +436,7 @@ class TensorNameMap:
|
|||
"transformer.h.{bid}.attn.c_proj", # gpt2 refact
|
||||
"transformer.blocks.{bid}.attn.out_proj", # mpt
|
||||
"transformer.h.{bid}.self_attention.dense", # falcon
|
||||
"h.{bid}.self_attention.dense", # bloom
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama-hf
|
||||
"layers.{bid}.attention.wo", # llama-pth
|
||||
"encoder.layer.{bid}.attention.output.dense", # bert
|
||||
|
@ -429,6 +454,7 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.FFN_NORM: (
|
||||
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
|
||||
"transformer.h.{bid}.ln_2", # gpt2 refact
|
||||
"h.{bid}.post_attention_layernorm", # bloom
|
||||
"transformer.blocks.{bid}.norm_2", # mpt
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama-hf
|
||||
"layers.{bid}.ffn_norm", # llama-pth
|
||||
|
@ -442,6 +468,7 @@ class TensorNameMap:
|
|||
"transformer.h.{bid}.mlp.c_fc", # gpt2
|
||||
"transformer.blocks.{bid}.ffn.up_proj", # mpt
|
||||
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
|
||||
"h.{bid}.mlp.dense_h_to_4h", # bloom
|
||||
"model.layers.{bid}.mlp.up_proj", # llama-hf refact
|
||||
"layers.{bid}.feed_forward.w3", # llama-pth
|
||||
"encoder.layer.{bid}.intermediate.dense", # bert
|
||||
|
@ -461,6 +488,7 @@ class TensorNameMap:
|
|||
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact
|
||||
"transformer.blocks.{bid}.ffn.down_proj", # mpt
|
||||
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
|
||||
"h.{bid}.mlp.dense_4h_to_h", # bloom
|
||||
"model.layers.{bid}.mlp.down_proj", # llama-hf
|
||||
"layers.{bid}.feed_forward.w2", # llama-pth
|
||||
"encoder.layer.{bid}.output.dense", # bert
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue