Add JAIS
model(s)
This commit is contained in:
parent
26a39bbd6b
commit
34300a03bc
8 changed files with 279 additions and 21 deletions
|
@ -86,6 +86,7 @@ models = [
|
||||||
{"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
|
{"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", },
|
||||||
{"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
|
{"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", },
|
||||||
{"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B
|
{"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B
|
||||||
|
{"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -427,9 +427,6 @@ class Model:
|
||||||
# NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script
|
# NOTE: if you get an error here, you need to update the convert-hf-to-gguf-update.py script
|
||||||
# or pull the latest version of the model from Huggingface
|
# or pull the latest version of the model from Huggingface
|
||||||
# don't edit the hashes manually!
|
# don't edit the hashes manually!
|
||||||
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
|
|
||||||
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
|
|
||||||
res = "llama-bpe"
|
|
||||||
if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754":
|
if chkhsh == "049ecf7629871e3041641907f3de7c733e4dbfdc736f57d882ba0b0845599754":
|
||||||
# ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base
|
# ref: https://huggingface.co/deepseek-ai/deepseek-llm-7b-base
|
||||||
res = "deepseek-llm"
|
res = "deepseek-llm"
|
||||||
|
@ -457,18 +454,12 @@ class Model:
|
||||||
if chkhsh == "6221ad2852e85ce96f791f476e0b390cf9b474c9e3d1362f53a24a06dc8220ff":
|
if chkhsh == "6221ad2852e85ce96f791f476e0b390cf9b474c9e3d1362f53a24a06dc8220ff":
|
||||||
# ref: https://huggingface.co/smallcloudai/Refact-1_6-base
|
# ref: https://huggingface.co/smallcloudai/Refact-1_6-base
|
||||||
res = "refact"
|
res = "refact"
|
||||||
if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8":
|
|
||||||
# ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01
|
|
||||||
res = "command-r"
|
|
||||||
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
|
if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea":
|
||||||
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
|
# ref: https://huggingface.co/Qwen/Qwen1.5-7B
|
||||||
res = "qwen2"
|
res = "qwen2"
|
||||||
if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
|
if chkhsh == "b6dc8df998e1cfbdc4eac8243701a65afe638679230920b50d6f17d81c098166":
|
||||||
# ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
|
# ref: https://huggingface.co/allenai/OLMo-1.7-7B-hf
|
||||||
res = "olmo"
|
res = "olmo"
|
||||||
if chkhsh == "a8594e3edff7c29c003940395316294b2c623e09894deebbc65f33f1515df79e":
|
|
||||||
# ref: https://huggingface.co/databricks/dbrx-base
|
|
||||||
res = "dbrx"
|
|
||||||
if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
|
if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f":
|
||||||
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
|
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-en
|
||||||
res = "jina-v2-en"
|
res = "jina-v2-en"
|
||||||
|
@ -490,6 +481,9 @@ class Model:
|
||||||
if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
|
if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
|
||||||
# ref: https://huggingface.co/LumiOpen/Viking-7B
|
# ref: https://huggingface.co/LumiOpen/Viking-7B
|
||||||
res = "viking"
|
res = "viking"
|
||||||
|
if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901":
|
||||||
|
# ref: https://huggingface.co/core42/jais-13b
|
||||||
|
res = "jais"
|
||||||
|
|
||||||
if res is None:
|
if res is None:
|
||||||
logger.warning("\n")
|
logger.warning("\n")
|
||||||
|
@ -2817,6 +2811,79 @@ class DeepseekV2Model(Model):
|
||||||
if len(experts) > 0:
|
if len(experts) > 0:
|
||||||
raise ValueError(f"Unprocessed experts: {experts}")
|
raise ValueError(f"Unprocessed experts: {experts}")
|
||||||
|
|
||||||
|
@Model.register("JAISLMHeadModel")
|
||||||
|
class JaisModel(Model):
|
||||||
|
model_arch = gguf.MODEL_ARCH.JAIS
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# SwigLU activation
|
||||||
|
assert self.hparams["activation_function"] == "swiglu"
|
||||||
|
# ALiBi position embedding
|
||||||
|
assert self.hparams["position_embedding_type"] == "alibi"
|
||||||
|
|
||||||
|
# Embeddings scale
|
||||||
|
self.embeddings_scale = 1.0
|
||||||
|
# note: For some JAIS flavors, output is tied to (same as) wte in original model
|
||||||
|
self.output_is_wte = False
|
||||||
|
if 'mup_embeddings_scale' in self.hparams:
|
||||||
|
self.output_is_wte = True # Hack (?)
|
||||||
|
self.embeddings_scale = self.hparams['mup_embeddings_scale']
|
||||||
|
elif 'embeddings_scale' in self.hparams:
|
||||||
|
self.embeddings_scale = self.hparams['embeddings_scale']
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
|
self.width_scale = 1.0
|
||||||
|
if 'mup_output_alpha' in self.hparams:
|
||||||
|
assert 'mup_width_scale' in self.hparams
|
||||||
|
self.width_scale = self.hparams['mup_output_alpha'] * self.hparams['mup_width_scale']
|
||||||
|
elif 'width_scale' in self.hparams:
|
||||||
|
self.width_scale = self.hparams['width_scale']
|
||||||
|
else:
|
||||||
|
assert False
|
||||||
|
|
||||||
|
def set_vocab(self):
|
||||||
|
self._set_vocab_gpt2()
|
||||||
|
|
||||||
|
def set_gguf_parameters(self):
|
||||||
|
self.gguf_writer.add_name(self.dir_model.name)
|
||||||
|
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
||||||
|
self.gguf_writer.add_context_length(self.hparams["n_positions"])
|
||||||
|
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
|
||||||
|
self.gguf_writer.add_feed_forward_length(self.hparams["n_inner"])
|
||||||
|
self.gguf_writer.add_head_count(self.hparams["n_head"])
|
||||||
|
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
|
||||||
|
self.gguf_writer.add_file_type(self.ftype)
|
||||||
|
|
||||||
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
|
del bid # unused
|
||||||
|
|
||||||
|
tensors: list[tuple[str, Tensor]] = []
|
||||||
|
|
||||||
|
# we don't need these
|
||||||
|
if name.endswith((".attn.bias", "relative_pe.slopes")):
|
||||||
|
return tensors
|
||||||
|
|
||||||
|
if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_fc2.weight")):
|
||||||
|
data_torch = data_torch.transpose(1, 0)
|
||||||
|
|
||||||
|
new_name = self.map_tensor_name(name)
|
||||||
|
|
||||||
|
if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
|
||||||
|
tensors.append((new_name, data_torch * self.embeddings_scale))
|
||||||
|
if self.output_is_wte:
|
||||||
|
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch * self.width_scale))
|
||||||
|
elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
|
||||||
|
assert not self.output_is_wte
|
||||||
|
tensors.append((new_name, data_torch * self.width_scale))
|
||||||
|
else:
|
||||||
|
tensors.append((new_name, data_torch))
|
||||||
|
|
||||||
|
return tensors
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Model.register("T5ForConditionalGeneration")
|
@Model.register("T5ForConditionalGeneration")
|
||||||
@Model.register("T5WithLMHeadModel")
|
@Model.register("T5WithLMHeadModel")
|
||||||
|
|
|
@ -733,7 +733,6 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// Console/Stream Output
|
// Console/Stream Output
|
||||||
fprintf(stdout, "%s", token_str.c_str());
|
fprintf(stdout, "%s", token_str.c_str());
|
||||||
|
|
||||||
// Record Displayed Tokens To Log
|
// Record Displayed Tokens To Log
|
||||||
// Note: Generated tokens are created one by one hence this check
|
// Note: Generated tokens are created one by one hence this check
|
||||||
if (embd.size() > 1) {
|
if (embd.size() > 1) {
|
||||||
|
|
|
@ -13516,13 +13516,13 @@ static void ggml_compute_forward_soft_max_f32(
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < nc; ++i) {
|
for (int i = 0; i < nc; ++i) {
|
||||||
wp[i] += slope*mp_f32[i];
|
wp[i] += slope*mp_f32[i];
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
for (int i = 0; i < nc; ++i) {
|
for (int i = 0; i < nc; ++i) {
|
||||||
//printf("p[%d] = %f\n", i, p[i]);
|
|
||||||
assert(!isnan(wp[i]));
|
assert(!isnan(wp[i]));
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -161,6 +161,7 @@ class MODEL_ARCH(IntEnum):
|
||||||
DEEPSEEK2 = auto()
|
DEEPSEEK2 = auto()
|
||||||
BITNET = auto()
|
BITNET = auto()
|
||||||
T5 = auto()
|
T5 = auto()
|
||||||
|
JAIS = auto()
|
||||||
|
|
||||||
|
|
||||||
class MODEL_TENSOR(IntEnum):
|
class MODEL_TENSOR(IntEnum):
|
||||||
|
@ -285,6 +286,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||||
MODEL_ARCH.BITNET: "bitnet",
|
MODEL_ARCH.BITNET: "bitnet",
|
||||||
MODEL_ARCH.T5: "t5",
|
MODEL_ARCH.T5: "t5",
|
||||||
|
MODEL_ARCH.JAIS: "jais",
|
||||||
}
|
}
|
||||||
|
|
||||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
|
@ -951,6 +953,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
MODEL_TENSOR.ENC_FFN_UP,
|
MODEL_TENSOR.ENC_FFN_UP,
|
||||||
MODEL_TENSOR.ENC_OUTPUT_NORM,
|
MODEL_TENSOR.ENC_OUTPUT_NORM,
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.JAIS: [
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.OUTPUT,
|
||||||
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_QKV,
|
||||||
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
|
MODEL_TENSOR.FFN_NORM,
|
||||||
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
MODEL_TENSOR.FFN_GATE,
|
||||||
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
],
|
||||||
# TODO
|
# TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ class TensorNameMap:
|
||||||
# Token embeddings
|
# Token embeddings
|
||||||
MODEL_TENSOR.TOKEN_EMBD: (
|
MODEL_TENSOR.TOKEN_EMBD: (
|
||||||
"gpt_neox.embed_in", # gptneox
|
"gpt_neox.embed_in", # gptneox
|
||||||
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx
|
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais
|
||||||
"transformer.word_embeddings", # falcon
|
"transformer.word_embeddings", # falcon
|
||||||
"word_embeddings", # bloom
|
"word_embeddings", # bloom
|
||||||
"model.embed_tokens", # llama-hf
|
"model.embed_tokens", # llama-hf
|
||||||
|
@ -49,7 +49,7 @@ class TensorNameMap:
|
||||||
# Output
|
# Output
|
||||||
MODEL_TENSOR.OUTPUT: (
|
MODEL_TENSOR.OUTPUT: (
|
||||||
"embed_out", # gptneox
|
"embed_out", # gptneox
|
||||||
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
|
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
|
||||||
"output", # llama-pth bloom internlm2
|
"output", # llama-pth bloom internlm2
|
||||||
"word_embeddings_for_head", # persimmon
|
"word_embeddings_for_head", # persimmon
|
||||||
"lm_head.linear", # phi2
|
"lm_head.linear", # phi2
|
||||||
|
@ -58,7 +58,7 @@ class TensorNameMap:
|
||||||
# Output norm
|
# Output norm
|
||||||
MODEL_TENSOR.OUTPUT_NORM: (
|
MODEL_TENSOR.OUTPUT_NORM: (
|
||||||
"gpt_neox.final_layer_norm", # gptneox
|
"gpt_neox.final_layer_norm", # gptneox
|
||||||
"transformer.ln_f", # gpt2 gpt-j falcon
|
"transformer.ln_f", # gpt2 gpt-j falcon jais
|
||||||
"model.norm", # llama-hf baichuan internlm2
|
"model.norm", # llama-hf baichuan internlm2
|
||||||
"norm", # llama-pth
|
"norm", # llama-pth
|
||||||
"transformer.norm_f", # mpt dbrx
|
"transformer.norm_f", # mpt dbrx
|
||||||
|
@ -81,7 +81,7 @@ class TensorNameMap:
|
||||||
# Attention norm
|
# Attention norm
|
||||||
MODEL_TENSOR.ATTN_NORM: (
|
MODEL_TENSOR.ATTN_NORM: (
|
||||||
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
|
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
|
||||||
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen
|
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais
|
||||||
"transformer.blocks.{bid}.norm_1", # mpt
|
"transformer.blocks.{bid}.norm_1", # mpt
|
||||||
"transformer.h.{bid}.input_layernorm", # falcon7b
|
"transformer.h.{bid}.input_layernorm", # falcon7b
|
||||||
"h.{bid}.input_layernorm", # bloom
|
"h.{bid}.input_layernorm", # bloom
|
||||||
|
@ -109,7 +109,7 @@ class TensorNameMap:
|
||||||
# Attention query-key-value
|
# Attention query-key-value
|
||||||
MODEL_TENSOR.ATTN_QKV: (
|
MODEL_TENSOR.ATTN_QKV: (
|
||||||
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
|
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
|
||||||
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen
|
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen jais
|
||||||
"transformer.blocks.{bid}.attn.Wqkv", # mpt
|
"transformer.blocks.{bid}.attn.Wqkv", # mpt
|
||||||
"transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx
|
"transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx
|
||||||
"transformer.h.{bid}.self_attention.query_key_value", # falcon
|
"transformer.h.{bid}.self_attention.query_key_value", # falcon
|
||||||
|
@ -160,7 +160,7 @@ class TensorNameMap:
|
||||||
# Attention output
|
# Attention output
|
||||||
MODEL_TENSOR.ATTN_OUT: (
|
MODEL_TENSOR.ATTN_OUT: (
|
||||||
"gpt_neox.layers.{bid}.attention.dense", # gptneox
|
"gpt_neox.layers.{bid}.attention.dense", # gptneox
|
||||||
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
|
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen jais
|
||||||
"transformer.blocks.{bid}.attn.out_proj", # mpt
|
"transformer.blocks.{bid}.attn.out_proj", # mpt
|
||||||
"transformer.h.{bid}.self_attention.dense", # falcon
|
"transformer.h.{bid}.self_attention.dense", # falcon
|
||||||
"h.{bid}.self_attention.dense", # bloom
|
"h.{bid}.self_attention.dense", # bloom
|
||||||
|
@ -202,7 +202,7 @@ class TensorNameMap:
|
||||||
# Feed-forward norm
|
# Feed-forward norm
|
||||||
MODEL_TENSOR.FFN_NORM: (
|
MODEL_TENSOR.FFN_NORM: (
|
||||||
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
|
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
|
||||||
"transformer.h.{bid}.ln_2", # gpt2 refact qwen
|
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais
|
||||||
"h.{bid}.post_attention_layernorm", # bloom
|
"h.{bid}.post_attention_layernorm", # bloom
|
||||||
"transformer.blocks.{bid}.norm_2", # mpt
|
"transformer.blocks.{bid}.norm_2", # mpt
|
||||||
"model.layers.{bid}.post_attention_layernorm", # llama-hf
|
"model.layers.{bid}.post_attention_layernorm", # llama-hf
|
||||||
|
@ -239,7 +239,7 @@ class TensorNameMap:
|
||||||
# Feed-forward up
|
# Feed-forward up
|
||||||
MODEL_TENSOR.FFN_UP: (
|
MODEL_TENSOR.FFN_UP: (
|
||||||
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
|
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
|
||||||
"transformer.h.{bid}.mlp.c_fc", # gpt2
|
"transformer.h.{bid}.mlp.c_fc", # gpt2 jais
|
||||||
"transformer.blocks.{bid}.ffn.up_proj", # mpt
|
"transformer.blocks.{bid}.ffn.up_proj", # mpt
|
||||||
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
|
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
|
||||||
"h.{bid}.mlp.dense_h_to_4h", # bloom
|
"h.{bid}.mlp.dense_h_to_4h", # bloom
|
||||||
|
@ -285,6 +285,7 @@ class TensorNameMap:
|
||||||
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
|
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
|
||||||
"layers.{bid}.feed_forward.w1", # llama-pth
|
"layers.{bid}.feed_forward.w1", # llama-pth
|
||||||
"transformer.h.{bid}.mlp.w2", # qwen
|
"transformer.h.{bid}.mlp.w2", # qwen
|
||||||
|
"transformer.h.{bid}.mlp.c_fc2", # jais
|
||||||
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
|
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
|
||||||
"model.layers.{bid}.feed_forward.w1", # internlm2
|
"model.layers.{bid}.feed_forward.w1", # internlm2
|
||||||
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
|
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
|
||||||
|
@ -308,7 +309,7 @@ class TensorNameMap:
|
||||||
# Feed-forward down
|
# Feed-forward down
|
||||||
MODEL_TENSOR.FFN_DOWN: (
|
MODEL_TENSOR.FFN_DOWN: (
|
||||||
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
|
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
|
||||||
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen
|
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen jais
|
||||||
"transformer.blocks.{bid}.ffn.down_proj", # mpt
|
"transformer.blocks.{bid}.ffn.down_proj", # mpt
|
||||||
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
|
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
|
||||||
"h.{bid}.mlp.dense_4h_to_h", # bloom
|
"h.{bid}.mlp.dense_4h_to_h", # bloom
|
||||||
|
|
|
@ -89,6 +89,7 @@ extern "C" {
|
||||||
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
|
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
|
||||||
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
|
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
|
||||||
LLAMA_VOCAB_PRE_TYPE_VIKING = 16,
|
LLAMA_VOCAB_PRE_TYPE_VIKING = 16,
|
||||||
|
LLAMA_VOCAB_PRE_TYPE_JAIS = 17,
|
||||||
};
|
};
|
||||||
|
|
||||||
// note: these values should be synchronized with ggml_rope
|
// note: these values should be synchronized with ggml_rope
|
||||||
|
|
177
src/llama.cpp
177
src/llama.cpp
|
@ -228,6 +228,7 @@ enum llm_arch {
|
||||||
LLM_ARCH_DEEPSEEK2,
|
LLM_ARCH_DEEPSEEK2,
|
||||||
LLM_ARCH_BITNET,
|
LLM_ARCH_BITNET,
|
||||||
LLM_ARCH_T5,
|
LLM_ARCH_T5,
|
||||||
|
LLM_ARCH_JAIS,
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -269,6 +270,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||||
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
||||||
{ LLM_ARCH_BITNET, "bitnet" },
|
{ LLM_ARCH_BITNET, "bitnet" },
|
||||||
{ LLM_ARCH_T5, "t5" },
|
{ LLM_ARCH_T5, "t5" },
|
||||||
|
{ LLM_ARCH_JAIS, "jais" },
|
||||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1230,6 +1232,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||||
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
|
{ LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_JAIS,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_OUTPUT, "output" },
|
||||||
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||||
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
{
|
{
|
||||||
|
@ -2029,6 +2046,7 @@ enum e_model {
|
||||||
MODEL_410M,
|
MODEL_410M,
|
||||||
MODEL_0_5B,
|
MODEL_0_5B,
|
||||||
MODEL_1B,
|
MODEL_1B,
|
||||||
|
MODEL_1_3B,
|
||||||
MODEL_1_4B,
|
MODEL_1_4B,
|
||||||
MODEL_2B,
|
MODEL_2B,
|
||||||
MODEL_2_8B,
|
MODEL_2_8B,
|
||||||
|
@ -4880,6 +4898,18 @@ static void llm_load_hparams(
|
||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_JAIS:
|
||||||
|
{
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||||
|
// TODO: become GGUF KV parameter
|
||||||
|
hparams.f_max_alibi_bias = 8.0f;
|
||||||
|
switch (hparams.n_layer) {
|
||||||
|
case 24: model.type = e_model::MODEL_1_3B; break;
|
||||||
|
case 40: model.type = e_model::MODEL_13B; break;
|
||||||
|
/* TODO: add variants */
|
||||||
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
|
}
|
||||||
|
} break;
|
||||||
default: (void)0;
|
default: (void)0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5111,6 +5141,9 @@ static void llm_load_vocab(
|
||||||
} else if (
|
} else if (
|
||||||
tokenizer_pre == "viking") {
|
tokenizer_pre == "viking") {
|
||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
|
||||||
|
} else if (
|
||||||
|
tokenizer_pre == "jais") {
|
||||||
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||||
}
|
}
|
||||||
|
@ -6908,7 +6941,6 @@ static bool llm_load_tensors(
|
||||||
case LLM_ARCH_BITNET:
|
case LLM_ARCH_BITNET:
|
||||||
{
|
{
|
||||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||||
|
|
||||||
// output
|
// output
|
||||||
{
|
{
|
||||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||||
|
@ -6943,6 +6975,43 @@ static bool llm_load_tensors(
|
||||||
layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1});
|
layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1});
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_JAIS:
|
||||||
|
{
|
||||||
|
// Output
|
||||||
|
{
|
||||||
|
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||||
|
model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
|
||||||
|
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
|
||||||
|
}
|
||||||
|
for (int i = 0; i < n_layer; ++i) {
|
||||||
|
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||||
|
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||||
|
|
||||||
|
auto & layer = model.layers[i];
|
||||||
|
|
||||||
|
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||||
|
layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
|
||||||
|
|
||||||
|
layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
|
||||||
|
layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa});
|
||||||
|
|
||||||
|
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||||
|
layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
|
||||||
|
|
||||||
|
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||||
|
layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
|
||||||
|
|
||||||
|
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
|
||||||
|
layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd});
|
||||||
|
|
||||||
|
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||||
|
layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff});
|
||||||
|
|
||||||
|
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||||
|
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
|
||||||
|
|
||||||
|
}
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("unknown architecture");
|
throw std::runtime_error("unknown architecture");
|
||||||
}
|
}
|
||||||
|
@ -12307,6 +12376,107 @@ struct llm_build_context {
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_cgraph * build_jais() {
|
||||||
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||||
|
|
||||||
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||||
|
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
|
||||||
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||||
|
|
||||||
|
struct ggml_tensor * cur;
|
||||||
|
//struct ggml_tensor * pos;
|
||||||
|
struct ggml_tensor * inpL;
|
||||||
|
|
||||||
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||||
|
|
||||||
|
// // inp_pos - contains the positions
|
||||||
|
// struct ggml_tensor * inp_pos = build_inp_pos();
|
||||||
|
|
||||||
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
|
||||||
|
// pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
|
||||||
|
// cb(pos, "pos_embd", -1);
|
||||||
|
|
||||||
|
// inpL = ggml_add(ctx0, inpL, pos);
|
||||||
|
// cb(inpL, "inpL", -1);
|
||||||
|
|
||||||
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
model.layers[il].attn_norm,
|
||||||
|
model.layers[il].attn_norm_b,
|
||||||
|
LLM_NORM, cb, il);
|
||||||
|
cb(cur, "attn_norm", il);
|
||||||
|
|
||||||
|
// self-attention
|
||||||
|
{
|
||||||
|
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
|
||||||
|
cb(cur, "wqkv", il);
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||||
|
cb(cur, "bqkv", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*cur->nb[0]*(n_embd)));
|
||||||
|
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd)));
|
||||||
|
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)));
|
||||||
|
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
|
||||||
|
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||||
|
|
||||||
|
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||||
|
model.layers[il].wo, model.layers[il].bo,
|
||||||
|
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/float(n_embd_head), cb, il);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (il == n_layer - 1) {
|
||||||
|
// skip computing output for unused tokens
|
||||||
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||||
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||||
|
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
// add the input
|
||||||
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||||
|
cb(ffn_inp, "ffn_inp", il);
|
||||||
|
|
||||||
|
// FF
|
||||||
|
{
|
||||||
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||||
|
model.layers[il].ffn_norm,
|
||||||
|
model.layers[il].ffn_norm_b,
|
||||||
|
LLM_NORM, cb, il);
|
||||||
|
cb(cur, "ffn_norm", il);
|
||||||
|
|
||||||
|
cur = llm_build_ffn(ctx0, cur,
|
||||||
|
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||||
|
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
||||||
|
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||||
|
NULL,
|
||||||
|
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||||
|
cb(cur, "ffn_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
inpL = ggml_add(ctx0, cur, ffn_inp);
|
||||||
|
cb(inpL, "l_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
model.output_norm,
|
||||||
|
model.output_norm_b,
|
||||||
|
LLM_NORM, cb, -1);
|
||||||
|
cb(cur, "result_norm", -1);
|
||||||
|
|
||||||
|
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||||
|
|
||||||
|
cb(cur, "result_output", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||||
|
@ -12538,6 +12708,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
{
|
{
|
||||||
result = llm.build_bitnet();
|
result = llm.build_bitnet();
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_JAIS:
|
||||||
|
{
|
||||||
|
result = llm.build_jais();
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
@ -17760,6 +17934,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||||
case LLM_ARCH_MAMBA:
|
case LLM_ARCH_MAMBA:
|
||||||
case LLM_ARCH_JINA_BERT_V2:
|
case LLM_ARCH_JINA_BERT_V2:
|
||||||
case LLM_ARCH_T5:
|
case LLM_ARCH_T5:
|
||||||
|
case LLM_ARCH_JAIS:
|
||||||
return LLAMA_ROPE_TYPE_NONE;
|
return LLAMA_ROPE_TYPE_NONE;
|
||||||
|
|
||||||
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue