llama : add support for StarCoder model architectures (#3187)

* add placeholder of starcoder in gguf / llama.cpp

* support convert starcoder weights to gguf

* convert MQA to MHA

* fix ffn_down name

* add LLM_ARCH_STARCODER to llama.cpp

* set head_count_kv = 1

* load starcoder weight

* add max_position_embeddings

* set n_positions to max_positioin_embeddings

* properly load all starcoder params

* fix head count kv

* fix comments

* fix vram calculation for starcoder

* store mqa directly

* add input embeddings handling

* add TBD

* working in cpu, metal buggy

* cleanup useless code

* metal : fix out-of-bounds access in soft_max kernels

* llama : make starcoder graph build more consistent with others

* refactor: cleanup comments a bit

* add other starcoder models: 3B, 7B, 15B

* support-mqa-directly

* fix: remove max_position_embeddings, use n_train_ctx

* Update llama.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update llama.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* fix: switch to space from tab

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Meng Zhang 2023-09-16 03:02:13 +08:00 committed by GitHub
parent 80291a1d02
commit 4fe09dfe66
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 637 additions and 21 deletions

View file

@ -77,13 +77,14 @@ KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
class MODEL_ARCH(IntEnum):
LLAMA : int = auto()
FALCON : int = auto()
BAICHUAN:int = auto()
GPT2 : int = auto()
GPTJ : int = auto()
GPTNEOX: int = auto()
MPT : int = auto()
LLAMA : int = auto()
FALCON : int = auto()
BAICHUAN : int = auto()
GPT2 : int = auto()
GPTJ : int = auto()
GPTNEOX : int = auto()
MPT : int = auto()
STARCODER : int = auto()
class MODEL_TENSOR(IntEnum):
@ -107,13 +108,14 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.BAICHUAN:"baichuan",
MODEL_ARCH.GPT2: "gpt2",
MODEL_ARCH.GPTJ: "gptj",
MODEL_ARCH.GPTNEOX: "gptneox",
MODEL_ARCH.MPT: "mpt",
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.BAICHUAN: "baichuan",
MODEL_ARCH.GPT2: "gpt2",
MODEL_ARCH.GPTJ: "gptj",
MODEL_ARCH.GPTNEOX: "gptneox",
MODEL_ARCH.MPT: "mpt",
MODEL_ARCH.STARCODER: "starcoder",
}
MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
@ -171,6 +173,18 @@ MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
},
MODEL_ARCH.STARCODER: {
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.POS_EMBD: "position_embd",
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
},
MODEL_ARCH.GPT2: {
# TODO
},