Added support for the Snowflake Arctic model.

Stanisław Szymczyk 2024-05-01 09:43:19 +02:00
parent c4ec9c0d3d
commit 71d8bd6480
4 changed files with 447 additions and 3 deletions

convert-hf-to-gguf.py

@@ -1516,6 +1516,119 @@ class LlamaModel(Model):
        if len(experts) > 0:
            raise ValueError(f"Unprocessed experts: {experts.keys()}")


@Model.register("ArcticForCausalLM")
class ArcticModel(Model):
    model_arch = gguf.MODEL_ARCH.ARCTIC

    def set_vocab(self):
        self._set_vocab_llama_hf()

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        hparams = self.hparams
        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
        self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])

    # Same as super class, but permuting q_proj, k_proj
    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        n_head = self.hparams.get("num_attention_heads")
        n_kv_head = self.hparams.get("num_key_value_heads")
        n_experts = self.hparams.get("num_local_experts")
        experts = dict()
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.numpy()

            if name.endswith("q_proj.weight"):
                data = permute(data, n_head, n_head)
            if name.endswith("k_proj.weight"):
                data = permute(data, n_head, n_kv_head)

            data = data.squeeze()

            # process the experts separately
            if name.find("block_sparse_moe.experts") != -1:
                experts[name] = data
                if len(experts) >= n_experts:
                    # merge the experts into a single 3d tensor
                    for bid in range(block_count):
                        for wid in range(1, 4):
                            full = True
                            for xid in range(n_experts):
                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight"
                                if ename not in experts:
                                    full = False
                                    break
                            if not full:
                                continue

                            datas = []
                            for xid in range(n_experts):
                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight"
                                datas.append(experts[ename])
                                del experts[ename]

                            data = np.stack(datas, axis=0)
                            data_dtype = data.dtype

                            if self.ftype == 0 and data_dtype == np.float16:
                                data = data.astype(np.float32)

                            if self.ftype == 1 and data_dtype == np.float32:
                                data = data.astype(np.float16)

                            merged_name = f"layers.{bid}.feed_forward.experts.w{wid}.weight"

                            new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
                            if new_name is None:
                                print(f"Can not map tensor {name!r}")
                                sys.exit()

                            print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")

                            self.gguf_writer.add_tensor(new_name, data)
                continue

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # 1d tensors need to be converted to float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)

        if len(experts) > 0:
            raise ValueError(f"Unprocessed experts: {experts.keys()}")
@Model.register("GrokForCausalLM")
class GrokModel(Model):
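For illustration, the expert merge above turns the per-expert 2-D w1/w2/w3 matrices into one 3-D tensor per projection, with the expert index as the leading axis. A minimal numpy sketch of that step (toy sizes standing in for Arctic's real dimensions; not part of the commit):

    import numpy as np

    n_experts, n_embd, n_ff = 4, 8, 16  # toy sizes; Arctic uses 128 experts

    # per-expert w1 matrices, keyed by their HF checkpoint names
    experts = {
        f"model.layers.0.block_sparse_moe.experts.{xid}.w1.weight":
            np.full((n_ff, n_embd), xid, dtype=np.float32)
        for xid in range(n_experts)
    }

    # the same stacking the converter performs: one 3-D tensor, expert index first
    datas = [experts[f"model.layers.0.block_sparse_moe.experts.{xid}.w1.weight"]
             for xid in range(n_experts)]
    merged = np.stack(datas, axis=0)
    assert merged.shape == (n_experts, n_ff, n_embd)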

gguf-py/gguf/constants.py

@@ -138,6 +138,7 @@ class MODEL_ARCH(IntEnum):
    COMMAND_R = auto()
    DBRX = auto()
    OLMO = auto()
    ARCTIC = auto()


class MODEL_TENSOR(IntEnum):

@@ -180,6 +181,7 @@ class MODEL_TENSOR(IntEnum):
    SSM_A = auto()
    SSM_D = auto()
    SSM_OUT = auto()
    FFN_NORM_EXP = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -215,6 +217,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.COMMAND_R: "command-r",
    MODEL_ARCH.DBRX: "dbrx",
    MODEL_ARCH.OLMO: "olmo",
    MODEL_ARCH.ARCTIC: "arctic",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -257,6 +260,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
    MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
    MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
    MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -725,6 +729,27 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.ARCTIC: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.ATTN_ROT_EMBD,
        MODEL_TENSOR.FFN_GATE_INP,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.FFN_GATE_EXP,
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
        MODEL_TENSOR.FFN_NORM_EXP,
    ],
    # TODO
}
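The TENSOR_NAMES entries are format templates keyed by block id; a one-line sketch of how the new ffn_norm_exps template expands (standalone toy dict, not the real gguf module):

    TENSOR_NAMES = {"FFN_NORM_EXP": "blk.{bid}.ffn_norm_exps"}  # mirrors the entry added above
    print(TENSOR_NAMES["FFN_NORM_EXP"].format(bid=3))           # -> blk.3.ffn_norm_exps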

gguf-py/gguf/tensor_mapping.py

@@ -370,6 +370,64 @@ class TensorNameMap:
            "model.layers.{bid}.out_proj",
            "backbone.layers.{bid}.mixer.out_proj",
        ),
    }

    # architecture-specific block mappings
    arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = {
        MODEL_ARCH.ARCTIC: {
            MODEL_TENSOR.TOKEN_EMBD: (
                "model.embed_tokens",
            ),
            MODEL_TENSOR.OUTPUT_NORM: (
                "model.norm",
            ),
            MODEL_TENSOR.OUTPUT: (
                "lm_head",
            ),
            MODEL_TENSOR.ATTN_NORM: (
                "model.layers.{bid}.input_layernorm",
            ),
            MODEL_TENSOR.ATTN_Q: (
                "model.layers.{bid}.self_attn.q_proj",
            ),
            MODEL_TENSOR.ATTN_K: (
                "model.layers.{bid}.self_attn.k_proj",
            ),
            MODEL_TENSOR.ATTN_V: (
                "model.layers.{bid}.self_attn.v_proj",
            ),
            MODEL_TENSOR.ATTN_OUT: (
                "model.layers.{bid}.self_attn.o_proj",
            ),
            MODEL_TENSOR.FFN_GATE_INP: (
                "model.layers.{bid}.block_sparse_moe.gate",
            ),
            MODEL_TENSOR.FFN_NORM: (
                "model.layers.{bid}.residual_layernorm",
            ),
            MODEL_TENSOR.FFN_GATE: (
                "model.layers.{bid}.residual_mlp.w1",
            ),
            MODEL_TENSOR.FFN_DOWN: (
                "model.layers.{bid}.residual_mlp.w2",
            ),
            MODEL_TENSOR.FFN_UP: (
                "model.layers.{bid}.residual_mlp.w3",
            ),
            MODEL_TENSOR.FFN_GATE_EXP: (
                "layers.{bid}.feed_forward.experts.w1",
            ),
            MODEL_TENSOR.FFN_DOWN_EXP: (
                "layers.{bid}.feed_forward.experts.w2",
            ),
            MODEL_TENSOR.FFN_UP_EXP: (
                "layers.{bid}.feed_forward.experts.w3",
            ),
            MODEL_TENSOR.FFN_NORM_EXP: (
                "model.layers.{bid}.post_attention_layernorm",
            ),
        },
    }

    mapping: dict[str, tuple[MODEL_TENSOR, str]]
@@ -383,12 +441,16 @@ class TensorNameMap:
            self.mapping[tensor_name] = (tensor, tensor_name)
            for key in keys:
                self.mapping[key] = (tensor, tensor_name)
        if arch in self.arch_block_mappings_cfg:
            block_mappings = self.arch_block_mappings_cfg[arch]
        else:
            block_mappings = self.block_mappings_cfg
        for bid in range(n_blocks):
-           for tensor, keys in self.block_mappings_cfg.items():
+           for tensor, keys in block_mappings.items():
                if tensor not in MODEL_TENSORS[arch]:
                    continue
                # TODO: make this configurable
-               n_experts = 60
+               n_experts = 128
                for xid in range(n_experts):
                    tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
                    self.mapping[tensor_name] = (tensor, tensor_name)
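The new arch_block_mappings_cfg makes the per-block lookup architecture-aware. Arctic needs this because it reuses the HF name post_attention_layernorm for the MoE sub-layer norm (FFN_NORM_EXP) and introduces residual_layernorm for the dense FFN norm, so the generic table would map them wrongly. A reduced sketch of the selection logic above (toy dicts and names, not the real classes):

    from enum import IntEnum, auto

    class MODEL_ARCH(IntEnum):
        LLAMA = auto()
        ARCTIC = auto()

    # generic per-block mapping used by most architectures
    block_mappings_cfg = {"FFN_NORM": ("model.layers.{bid}.post_attention_layernorm",)}
    # Arctic-specific override: post_attention_layernorm feeds the MoE branch instead
    arch_block_mappings_cfg = {
        MODEL_ARCH.ARCTIC: {
            "FFN_NORM": ("model.layers.{bid}.residual_layernorm",),
            "FFN_NORM_EXP": ("model.layers.{bid}.post_attention_layernorm",),
        },
    }

    def mappings_for(arch: MODEL_ARCH) -> dict:
        # same fallback rule as TensorNameMap.__init__ above
        return arch_block_mappings_cfg.get(arch, block_mappings_cfg)

    assert mappings_for(MODEL_ARCH.LLAMA) is block_mappings_cfg
    assert mappings_for(MODEL_ARCH.ARCTIC)["FFN_NORM"][0] == "model.layers.{bid}.residual_layernorm"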

llama.cpp

@@ -106,7 +106,7 @@
#endif

#define LLAMA_MAX_NODES   8192
-#define LLAMA_MAX_EXPERTS 60
+#define LLAMA_MAX_EXPERTS 128

//
// logging

@@ -224,6 +224,7 @@ enum llm_arch {
    LLM_ARCH_COMMAND_R,
    LLM_ARCH_DBRX,
    LLM_ARCH_OLMO,
    LLM_ARCH_ARCTIC,
    LLM_ARCH_UNKNOWN,
};

@@ -260,6 +261,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_COMMAND_R, "command-r" },
    { LLM_ARCH_DBRX,      "dbrx"      },
    { LLM_ARCH_OLMO,      "olmo"      },
    { LLM_ARCH_ARCTIC,    "arctic"    },
    { LLM_ARCH_UNKNOWN,   "(unknown)" },
};

@@ -457,6 +459,7 @@ enum llm_tensor {
    LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
    LLM_TENSOR_FFN_GATE_EXPS,
    LLM_TENSOR_FFN_UP_EXPS,
    LLM_TENSOR_FFN_NORM_EXPS,
    LLM_TENSOR_FFN_DOWN_SHEXP,
    LLM_TENSOR_FFN_GATE_SHEXP,
    LLM_TENSOR_FFN_UP_SHEXP,
@@ -1027,6 +1030,28 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
            { LLM_TENSOR_FFN_UP,   "blk.%d.ffn_up" },
        },
    },
    {
        LLM_ARCH_ARCTIC,
        {
            { LLM_TENSOR_TOKEN_EMBD,    "token_embd" },
            { LLM_TENSOR_OUTPUT_NORM,   "output_norm" },
            { LLM_TENSOR_OUTPUT,        "output" },
            { LLM_TENSOR_ATTN_NORM,     "blk.%d.attn_norm" },
            { LLM_TENSOR_ATTN_Q,        "blk.%d.attn_q" },
            { LLM_TENSOR_ATTN_K,        "blk.%d.attn_k" },
            { LLM_TENSOR_ATTN_V,        "blk.%d.attn_v" },
            { LLM_TENSOR_ATTN_OUT,      "blk.%d.attn_output" },
            { LLM_TENSOR_FFN_GATE_INP,  "blk.%d.ffn_gate_inp" },
            { LLM_TENSOR_FFN_NORM,      "blk.%d.ffn_norm" },
            { LLM_TENSOR_FFN_GATE,      "blk.%d.ffn_gate" },
            { LLM_TENSOR_FFN_DOWN,      "blk.%d.ffn_down" },
            { LLM_TENSOR_FFN_UP,        "blk.%d.ffn_up" },
            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
            { LLM_TENSOR_FFN_UP_EXPS,   "blk.%d.ffn_up_exps" },
            { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" },
        },
    },
    {
        LLM_ARCH_UNKNOWN,
        {
@@ -1803,6 +1828,7 @@ enum e_model {
    MODEL_8x7B,
    MODEL_8x22B,
    MODEL_16x12B,
    MODEL_10B_128x3_66B,
};

static const size_t kiB = 1024;

@@ -1975,6 +2001,7 @@ struct llama_layer {
    struct ggml_tensor * ffn_norm_b;
    struct ggml_tensor * layer_out_norm;
    struct ggml_tensor * layer_out_norm_b;
    struct ggml_tensor * ffn_norm_exps;

    // ff
    struct ggml_tensor * ffn_gate; // w1

@@ -3734,6 +3761,7 @@ static const char * llama_model_type_name(e_model type) {
        case MODEL_8x7B:          return "8x7B";
        case MODEL_8x22B:         return "8x22B";
        case MODEL_16x12B:        return "16x12B";
        case MODEL_10B_128x3_66B: return "10B+128x3.66B";
        default:                  return "?B";
    }
}
@@ -4196,6 +4224,20 @@ static void llm_load_hparams(
                    default: model.type = e_model::MODEL_UNKNOWN;
                }
            } break;
        case LLM_ARCH_ARCTIC:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                if (hparams.n_expert == 128) {
                    switch (hparams.n_layer) {
                        case 35: model.type = e_model::MODEL_10B_128x3_66B; break;
                        default: model.type = e_model::MODEL_UNKNOWN;
                    }
                } else {
                    model.type = e_model::MODEL_UNKNOWN;
                }
            } break;
        default: (void)0;
    }
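The new MODEL_10B_128x3_66B type encodes Arctic's published dense-MoE hybrid layout: a roughly 10B-parameter dense transformer combined with 128 experts of about 3.66B parameters each, with two experts active per token. A quick sanity check of the headline figures (numbers from Snowflake's announcement, not computed by this commit):

    dense_b  = 10.0   # ~10B dense transformer parameters
    expert_b = 3.66   # ~3.66B parameters per expert MLP
    n_expert = 128
    top_k    = 2      # experts used per token

    print(f"total:  ~{dense_b + n_expert * expert_b:.0f}B")  # ~478B, commonly rounded to 480B
    print(f"active: ~{dense_b + top_k * expert_b:.1f}B")     # ~17.3B active per token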
@@ -5932,6 +5974,55 @@ static bool llm_load_tensors(
                        layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
                    }
                } break;
            case LLM_ARCH_ARCTIC:
                {
                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

                    // output
                    {
                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
                        model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
                        // if output is NULL, init from the input tok embed
                        if (model.output == NULL) {
                            model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
                            ml.n_created--; // artificial tensor
                            ml.size_data += ggml_nbytes(model.output);
                        }
                    }

                    for (int i = 0; i < n_layer; ++i) {
                        ggml_context * ctx_layer = ctx_for_layer(i);
                        ggml_context * ctx_split = ctx_for_layer_split(i);

                        auto & layer = model.layers[i];

                        layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});

                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd});
                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa});
                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
                        layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});

                        // optional bias tensors
                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     false);
                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, false);
                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, false);
                        layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     false);

                        layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                        layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd});
                        layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd});
                        layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_embd});

                        layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
                        layer.ffn_norm_exps = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd});
                        layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false);
                        layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert});
                        layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert});
                    }
                } break;
            default:
                throw std::runtime_error("unknown architecture");
        }
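A note on the expert shapes above: ggml lists dimensions with the fastest-varying one first, so {n_embd, n_ff, n_expert} in create_tensor is the same memory layout as the (n_expert, n_ff, n_embd) numpy array the converter builds with np.stack(datas, axis=0). A minimal sketch of the correspondence (toy sizes):

    import numpy as np

    n_expert, n_ff, n_embd = 4, 16, 8
    merged = np.stack([np.zeros((n_ff, n_embd), dtype=np.float32)] * n_expert, axis=0)

    assert merged.shape == (n_expert, n_ff, n_embd)                   # numpy: slowest dim first
    assert tuple(reversed(merged.shape)) == (n_embd, n_ff, n_expert)  # ggml ne[] order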
@@ -10682,6 +10773,154 @@ struct llm_build_context {

        return gf;
    }

    struct ggml_cgraph * build_arctic() {
        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

        // mutable variable, needed during the last layer of the computation to skip unused tokens
        int32_t n_tokens = this->n_tokens;

        const int64_t n_embd_head = hparams.n_embd_head_v;
        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
        GGML_ASSERT(n_embd_head == hparams.n_rot);

        struct ggml_tensor * cur;
        struct ggml_tensor * inpL;

        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

        // inp_pos - contains the positions
        struct ggml_tensor * inp_pos = build_inp_pos();

        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

        for (int il = 0; il < n_layer; ++il) {
            struct ggml_tensor * inpSA = inpL;

            // norm
            cur = llm_build_norm(ctx0, inpL, hparams,
                    model.layers[il].attn_norm, NULL,
                    LLM_NORM_RMS, cb, il);
            cb(cur, "attn_norm", il);

            // self-attention
            {
                // compute Q and K and RoPE them
                struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
                cb(Qcur, "Qcur", il);
                if (model.layers[il].bq) {
                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                    cb(Qcur, "Qcur", il);
                }

                struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
                cb(Kcur, "Kcur", il);
                if (model.layers[il].bk) {
                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                    cb(Kcur, "Kcur", il);
                }

                struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                cb(Vcur, "Vcur", il);
                if (model.layers[il].bv) {
                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                    cb(Vcur, "Vcur", il);
                }

                Qcur = ggml_rope_custom(
                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow
                );
                cb(Qcur, "Qcur", il);

                Kcur = ggml_rope_custom(
                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow
                );
                cb(Kcur, "Kcur", il);

                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                        model.layers[il].wo, model.layers[il].bo,
                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
            }

            if (il == n_layer - 1) {
                // skip computing output for unused tokens
                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
                n_tokens = n_outputs;
                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
            cb(ffn_inp, "ffn_inp", il);

            // feed-forward network
            cur = llm_build_norm(ctx0, ffn_inp, hparams,
                    model.layers[il].ffn_norm, NULL,
                    LLM_NORM_RMS, cb, il);
            cb(cur, "ffn_norm", il);

            cur = llm_build_ffn(ctx0, cur,
                    model.layers[il].ffn_up,   NULL,
                    model.layers[il].ffn_gate, NULL,
                    model.layers[il].ffn_down, NULL,
                    NULL,
                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
            cb(cur, "ffn_out", il);

            struct ggml_tensor * ffn_out = ggml_add(ctx0, cur, ffn_inp);
            cb(ffn_out, "ffn_out", il);

            // MoE
            cur = llm_build_norm(ctx0, inpSA, hparams,
                    model.layers[il].ffn_norm_exps, NULL,
                    LLM_NORM_RMS, cb, il);
            cb(cur, "ffn_norm_exps", il);

            cur = llm_build_moe_ffn(ctx0, cur,
                    model.layers[il].ffn_gate_inp,
                    model.layers[il].ffn_up_exps,
                    model.layers[il].ffn_gate_exps,
                    model.layers[il].ffn_down_exps,
                    n_expert, n_expert_used,
                    LLM_FFN_SILU, true,
                    cb, il);
            cb(cur, "ffn_moe_out", il);

            cur = ggml_add(ctx0, cur, ffn_out);
            cb(cur, "ffn_out", il);

            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
            if (layer_dir != nullptr) {
                cur = ggml_add(ctx0, cur, layer_dir);
            }
            cb(cur, "l_out", il);

            // input for next layer
            inpL = cur;
        }

        cur = inpL;

        cur = llm_build_norm(ctx0, cur, hparams,
                model.output_norm, NULL,
                LLM_NORM_RMS, cb, -1);
        cb(cur, "result_norm", -1);

        // lm_head
        cur = ggml_mul_mat(ctx0, model.output, cur);
        cb(cur, "result_output", -1);

        ggml_build_forward_expand(gf, cur);

        return gf;
    }
};

static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
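build_arctic wires each block as a parallel pair of feed-forward paths: a small dense residual MLP applied to the attention output (ffn_norm, then llm_build_ffn), plus a 128-expert MoE computed from the block's *input* (inpSA) under its own RMS norm (ffn_norm_exps), with the two results summed. A simplified single-token numpy sketch of that dataflow (toy dimensions, naive top-k routing; an illustration, not the ggml implementation):

    import numpy as np

    def rms_norm(x, w, eps=1e-5):
        return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * w

    def silu(x):
        return x / (1.0 + np.exp(-x))

    def arctic_block_ffn(inp_sa, attn_out, layer, n_expert_used=2):
        # dense residual MLP path (ffn_norm -> SiLU-gated MLP), as in llm_build_ffn above
        ffn_inp = attn_out + inp_sa
        h = rms_norm(ffn_inp, layer["ffn_norm"])
        dense = (silu(h @ layer["w1"]) * (h @ layer["w3"])) @ layer["w2"]
        ffn_out = dense + ffn_inp

        # MoE path taken from the block input, under its own norm (ffn_norm_exps)
        m = rms_norm(inp_sa, layer["ffn_norm_exps"])
        logits = m @ layer["gate_inp"]             # router scores, one per expert
        top = np.argsort(logits)[-n_expert_used:]  # pick the top-k experts
        probs = np.exp(logits[top]) / np.exp(logits[top]).sum()  # normalized, cf. the 'true' flag
        moe = sum(p * ((silu(m @ layer["w1_exp"][e]) * (m @ layer["w3_exp"][e])) @ layer["w2_exp"][e])
                  for p, e in zip(probs, top))
        return moe + ffn_out  # the final ggml_add above

    rng = np.random.default_rng(0)
    n_embd, n_ff, n_expert = 8, 16, 4  # toy sizes; Arctic uses 128 experts, top-2 routing
    layer = {
        "ffn_norm": np.ones(n_embd), "ffn_norm_exps": np.ones(n_embd),
        "w1": rng.normal(size=(n_embd, n_ff)), "w2": rng.normal(size=(n_ff, n_embd)),
        "w3": rng.normal(size=(n_embd, n_ff)),
        "gate_inp": rng.normal(size=(n_embd, n_expert)),
        "w1_exp": rng.normal(size=(n_expert, n_embd, n_ff)),
        "w2_exp": rng.normal(size=(n_expert, n_ff, n_embd)),
        "w3_exp": rng.normal(size=(n_expert, n_embd, n_ff)),
    }
    x = rng.normal(size=n_embd)
    attn_out = rng.normal(size=n_embd)  # stand-in for the attention sub-layer output
    assert arctic_block_ffn(x, attn_out, layer).shape == (n_embd,)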
@@ -10895,6 +11134,10 @@ static struct ggml_cgraph * llama_build_graph(
            {
                result = llm.build_olmo();
            } break;
        case LLM_ARCH_ARCTIC:
            {
                result = llm.build_arctic();
            } break;
        default:
            GGML_ASSERT(false);
    }
@@ -15783,6 +16026,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
        case LLM_ARCH_XVERSE:
        case LLM_ARCH_COMMAND_R:
        case LLM_ARCH_OLMO:
        case LLM_ARCH_ARCTIC:
            return LLAMA_ROPE_TYPE_NORM;

        // the pairs of head values are offset by n_rot/2