llama : add llm_build_norm helper function
ggml-ci
parent 210e6e5d02
commit 7db9c96d8a
1 changed file with 175 additions and 258 deletions
llama.cpp (433 changed lines)
@@ -972,7 +972,7 @@ struct llama_mlock {

 typedef void (*offload_func_t)(struct ggml_tensor * tensor);

-static void ggml_offload_nop(struct ggml_tensor * tensor) { // don't offload by default
+static void ggml_offload_nop(struct ggml_tensor * tensor) {
     (void) tensor;
 }

@@ -3093,6 +3093,42 @@ static bool llama_model_load(

+using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * name, int nl)>;
+
+enum llm_norm_type {
+    LLM_NORM,
+    LLM_NORM_RMS,
+};
+
+static struct ggml_tensor * llm_build_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor * cur,
+        struct ggml_tensor * mw,
+        struct ggml_tensor * mb,
+        llm_norm_type type,
+        float eps,
+        const llm_build_cb & cb,
+        int il) {
+    switch (type) {
+        case LLM_NORM:     cur = ggml_norm    (ctx, cur, eps); break;
+        case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, eps); break;
+    };
+
+    if (mw || mb) {
+        cb(cur, "norm", il);
+    }
+
+    if (mw) {
+        cur = ggml_mul(ctx, cur, mw);
+        if (mb) {
+            cb(cur, "norm_w", il);
+        }
+    }
+
+    if (mb) {
+        cur = ggml_add(ctx, cur, mb);
+    }
+
+    return cur;
+}
+
 static struct ggml_cgraph * llm_build_llama(
         llama_context & lctx,
     const llama_batch & batch,

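For reference, the two llm_norm_type cases select between two different normalizations: LLM_NORM corresponds to ggml_norm (standard LayerNorm: subtract the mean, divide by the standard deviation) and LLM_NORM_RMS corresponds to ggml_rms_norm (divide by the root mean square only). The sketch below is not part of the commit; it is a minimal standalone re-implementation of the same math on plain std::vector values instead of ggml tensors, with mw/mb standing in for the optional weight and bias tensors, and it uses its own hypothetical names (norm_ref, llm_norm_type_ref) to avoid any confusion with the actual llama.cpp code.

#include <cmath>
#include <cstdio>
#include <vector>

// Standalone sketch only (not llama.cpp code): an eager reference for what
// llm_build_norm assembles as graph operations. eps keeps the division well-defined.
enum llm_norm_type_ref { REF_NORM, REF_NORM_RMS };

static std::vector<float> norm_ref(std::vector<float> x, llm_norm_type_ref type, float eps,
        const std::vector<float> * mw, const std::vector<float> * mb) {
    const size_t n = x.size();
    if (type == REF_NORM) {
        float mean = 0.0f;
        for (float v : x) mean += v;
        mean /= n;
        float var = 0.0f;
        for (float v : x) var += (v - mean)*(v - mean);
        var /= n;
        for (float & v : x) v = (v - mean)/std::sqrt(var + eps); // LayerNorm core (ggml_norm)
    } else {
        float ms = 0.0f;
        for (float v : x) ms += v*v;
        ms /= n;
        for (float & v : x) v /= std::sqrt(ms + eps);            // RMSNorm core (ggml_rms_norm)
    }
    for (size_t i = 0; i < n; ++i) {
        if (mw) x[i] *= (*mw)[i]; // plays the role of ggml_mul(ctx, cur, mw)
        if (mb) x[i] += (*mb)[i]; // plays the role of ggml_add(ctx, cur, mb)
    }
    return x;
}

int main() {
    const std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
    const std::vector<float> w = {0.5f, 0.5f, 0.5f, 0.5f};
    for (float v : norm_ref(x, REF_NORM_RMS, 1e-5f, &w, nullptr)) {
        std::printf("%.4f ", v);
    }
    std::printf("\n");
    return 0;
}

Note also, visible in the helper itself, that the callback fires with the name "norm" only when a weight or bias follows and with "norm_w" only when a bias is still to be added; the call sites then attach the final names ("attn_norm", "ffn_norm", "result_norm", ...) with their own cb() calls, as the hunks below show.
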
@@ -3192,14 +3228,11 @@ static struct ggml_cgraph * llm_build_llama(
         struct ggml_tensor * inpSA = inpL;

-        // norm
-        {
-            cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
-            cb(cur, "rms_norm_0", il);
-
-            // cur = cur*attn_norm(broadcasted)
-            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
-            cb(cur, "attn_norm_0", il);
-        }
+        cur = llm_build_norm(ctx0, inpL,
+                model.layers[il].attn_norm,
+                NULL,
+                LLM_NORM_RMS, norm_rms_eps, cb, il);
+        cb(cur, "attn_norm", il);

         // self-attention
         {

@@ -3307,15 +3340,11 @@ static struct ggml_cgraph * llm_build_llama(

         // feed-forward network
         {
-            // norm
-            {
-                cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
-                cb(cur, "rms_norm_1", il);
-
-                // cur = cur*ffn_norm(broadcasted)
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
-                cb(cur, "ffn_norm", il);
-            }
+            cur = llm_build_norm(ctx0, inpFF,
+                    model.layers[il].ffn_norm,
+                    NULL,
+                    LLM_NORM_RMS, norm_rms_eps, cb, il);
+            cb(cur, "ffn_norm", il);

             struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                     model.layers[il].w3,

@@ -3349,15 +3378,11 @@ static struct ggml_cgraph * llm_build_llama(

     cur = inpL;

-    // norm
-    {
-        cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
-        cb(cur, "rms_norm_2", -1);
-
-        // cur = cur*norm(broadcasted)
-        cur = ggml_mul(ctx0, cur, model.output_norm);
-        cb(cur, "result_norm", -1);
-    }
+    cur = llm_build_norm(ctx0, cur,
+            model.output_norm,
+            NULL,
+            LLM_NORM_RMS, norm_rms_eps, cb, -1);
+    cb(cur, "result_norm", -1);

     // lm_head
     cur = ggml_mul_mat(ctx0, model.output, cur);

@@ -3466,15 +3491,11 @@ static struct ggml_cgraph * llm_build_baichaun(
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * inpSA = inpL;

-        // norm
-        {
-            cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
-            cb(cur, "rms_norm_0", il);
-
-            // cur = cur*attn_norm(broadcasted)
-            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
-            cb(cur, "attn_norm_0", il);
-        }
+        cur = llm_build_norm(ctx0, inpL,
+                model.layers[il].attn_norm,
+                NULL,
+                LLM_NORM_RMS, norm_rms_eps, cb, il);
+        cb(cur, "attn_norm", il);

         // self-attention
         {

@@ -3600,15 +3621,11 @@ static struct ggml_cgraph * llm_build_baichaun(

         // feed-forward network
         {
-            // norm
-            {
-                cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
-                cb(cur, "rms_norm_1", il);
-
-                // cur = cur*ffn_norm(broadcasted)
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
-                cb(cur, "ffn_norm", il);
-            }
+            cur = llm_build_norm(ctx0, inpFF,
+                    model.layers[il].ffn_norm,
+                    NULL,
+                    LLM_NORM_RMS, norm_rms_eps, cb, il);
+            cb(cur, "ffn_norm", il);

             struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                     model.layers[il].w3,

@@ -3763,27 +3780,21 @@ static struct ggml_cgraph * llm_build_falcon(
         struct ggml_tensor * attn_norm;

         // self-attention
-        // TODO: refactor into common function (shared with LLaMA)
         {
-            attn_norm = ggml_norm(ctx0, inpL, norm_eps);
-            cb(attn_norm, "attn_norm_0", il);
-
-            attn_norm = ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm);
-            cb(attn_norm, "attn_norm_0_w", il);
-
-            attn_norm = ggml_add(ctx0, attn_norm, model.layers[il].attn_norm_b);
-            cb(attn_norm, "attn_norm_0_wb", il);
-
-            if (model.layers[il].attn_norm_2) { // Falcon-40B
-                cur = ggml_norm(ctx0, inpL, norm_eps);
+            attn_norm = llm_build_norm(ctx0, inpL,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, norm_eps, cb, il);
+            cb(attn_norm, "attn_norm", il);
+
+            if (model.layers[il].attn_norm_2) {
+                // Falcon-40B
+                cur = llm_build_norm(ctx0, attn_norm,
+                        model.layers[il].attn_norm_2,
+                        model.layers[il].attn_norm_2_b,
+                        LLM_NORM, norm_eps, cb, il);
                 cb(cur, "attn_norm_2", il);
-
-                cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm_2);
-                cb(cur, "attn_norm_2_w", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_2_b);
-                cb(cur, "attn_norm_2_wb", il);
-            } else { // Falcon 7B
+            } else {
                 cur = attn_norm;
             }

@@ -3925,16 +3936,11 @@ static struct ggml_cgraph * llm_build_falcon(
     cur = inpL;

-    // norm
-    {
-        cur = ggml_norm(ctx0, cur, norm_eps);
-        cb(cur, "out_norm_0", -1);
-
-        cur = ggml_mul(ctx0, cur, model.output_norm);
-        cb(cur, "out_norm_0_w", -1);
-
-        cur = ggml_add(ctx0, cur, model.output_norm_b);
-        cb(cur, "result_norm", -1);
-    }
+    cur = llm_build_norm(ctx0, cur,
+            model.output_norm,
+            model.output_norm_b,
+            LLM_NORM, norm_eps, cb, -1);
+    cb(cur, "result_norm", -1);

     cur = ggml_mul_mat(ctx0, model.output, cur);
     cb(cur, "result_output", -1);

@@ -4024,17 +4030,11 @@ static struct ggml_cgraph * llm_build_starcoder(
     cb(inpL, "inpL", -1);

     for (int il = 0; il < n_layer; ++il) {
-        {
-            // Norm
-            cur = ggml_norm(ctx0, inpL, norm_eps);
-            cb(cur, "attn_norm_0", il);
-
-            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
-            cb(cur, "attn_norm_0_w", il);
-
-            cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b);
-            cb(cur, "attn_norm_0_wb", il);
-        }
+        cur = llm_build_norm(ctx0, inpL,
+                model.layers[il].attn_norm,
+                model.layers[il].attn_norm_b,
+                LLM_NORM, norm_eps, cb, il);
+        cb(cur, "attn_norm", il);

         {
             // Self Attention

@@ -4130,17 +4130,11 @@ static struct ggml_cgraph * llm_build_starcoder(

         // FF
         {
-            // Norm
-            {
-                cur = ggml_norm(ctx0, inpFF, norm_eps);
-                cb(cur, "ffn_norm_0", il);
-
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
-                cb(cur, "ffn_norm_0_w", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].ffn_norm_b);
-                cb(cur, "ffn_norm_0_wb", il);
-            }
+            cur = llm_build_norm(ctx0, inpFF,
+                    model.layers[il].ffn_norm,
+                    model.layers[il].ffn_norm_b,
+                    LLM_NORM, norm_eps, cb, il);
+            cb(cur, "ffn_norm", il);

             cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
             cb(cur, "result_w3", il);

@@ -4161,17 +4155,11 @@ static struct ggml_cgraph * llm_build_starcoder(

     }

-    // Output Norm
-    {
-        cur = ggml_norm(ctx0, inpL, norm_eps);
-        cb(cur, "out_norm_0", -1);
-
-        cur = ggml_mul(ctx0, cur, model.output_norm);
-        cb(cur, "out_norm_0_w", -1);
-
-        cur = ggml_add(ctx0, cur, model.output_norm_b);
-        cb(cur, "result_norm", -1);
-    }
+    cur = llm_build_norm(ctx0, inpL,
+            model.output_norm,
+            model.output_norm_b,
+            LLM_NORM, norm_eps, cb, -1);
+    cb(cur, "result_norm", -1);

     cur = ggml_mul_mat(ctx0, model.output, cur);
     cb(cur, "result_output", -1);

@@ -4206,7 +4194,7 @@ static struct ggml_cgraph * llm_build_persimmon(

     const float freq_base  = cparams.rope_freq_base;
     const float freq_scale = cparams.rope_freq_scale;
-    const float norm_eps = hparams.f_norm_eps;
+    const float norm_eps = hparams.f_norm_eps;

     const int32_t n_tokens = batch.n_tokens;
     const int32_t n_kv     = worst_case ? n_ctx : kv_self.n;

@@ -4271,16 +4259,11 @@ static struct ggml_cgraph * llm_build_persimmon(
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * residual = inpL;

-        {
-            cur = ggml_norm(ctx0, inpL, norm_eps);
-            cb(cur, "attn_norm_0", il);
-
-            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
-            cb(cur, "attn_norm_0_w", il);
-
-            cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b);
-            cb(cur, "attn_norm_0_wb", il);
-        }
+        cur = llm_build_norm(ctx0, inpL,
+                model.layers[il].attn_norm,
+                model.layers[il].attn_norm_b,
+                LLM_NORM, norm_eps, cb, il);
+        cb(cur, "attn_norm", il);

         // self attention
         {

@@ -4316,22 +4299,16 @@ static struct ggml_cgraph * llm_build_persimmon(
             cb(tmpk, "tmpk", il);

             // Q/K Layernorm
-            tmpq = ggml_norm(ctx0, tmpq, norm_eps);
+            tmpq = llm_build_norm(ctx0, tmpq,
+                    model.layers[il].attn_q_norm,
+                    model.layers[il].attn_q_norm_b,
+                    LLM_NORM, norm_eps, cb, il);
             cb(tmpq, "tmpq", il);

-            tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm);
-            cb(tmpq, "tmpq", il);
-
-            tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b);
-            cb(tmpq, "tmpq", il);
-
-            tmpk = ggml_norm(ctx0, tmpk, norm_eps);
-            cb(tmpk, "tmpk", il);
-
-            tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm);
-            cb(tmpk, "tmpk", il);
-
-            tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b);
+            tmpk = llm_build_norm(ctx0, tmpk,
+                    model.layers[il].attn_k_norm,
+                    model.layers[il].attn_k_norm_b,
+                    LLM_NORM, norm_eps, cb, il);
             cb(tmpk, "tmpk", il);

             // RoPE the first n_rot of q/k, pass the other half, and concat.

@@ -4480,17 +4457,11 @@ static struct ggml_cgraph * llm_build_persimmon(

         {
             // MLP
-            {
-                // Norm
-                cur = ggml_norm(ctx0, inpFF, norm_eps);
-                cb(cur, "ffn_norm_0", il);
-
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
-                cb(cur, "ffn_norm_0_w", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].ffn_norm_b);
-                cb(cur, "ffn_norm_0_wb", il);
-            }
+            cur = llm_build_norm(ctx0, inpFF,
+                    model.layers[il].ffn_norm,
+                    model.layers[il].ffn_norm_b,
+                    LLM_NORM, norm_eps, cb, il);
+            cb(cur, "ffn_norm", il);

             cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
             cb(cur, "result_w3", il);

@@ -4519,16 +4490,11 @@ static struct ggml_cgraph * llm_build_persimmon(

     cur = inpL;

-    {
-        cur = ggml_norm(ctx0, cur, norm_eps);
-        cb(cur, "out_norm_0", -1);
-
-        cur = ggml_mul(ctx0, cur, model.output_norm);
-        cb(cur, "out_norm_0_w", -1);
-
-        cur = ggml_add(ctx0, cur, model.output_norm_b);
-        cb(cur, "result_norm", -1);
-    }
+    cur = llm_build_norm(ctx0, cur,
+            model.output_norm,
+            model.output_norm_b,
+            LLM_NORM, norm_eps, cb, -1);
+    cb(cur, "result_norm", -1);

     cur = ggml_mul_mat(ctx0, model.output, cur);
     cb(cur, "result_output", -1);

@@ -4609,15 +4575,11 @@ static struct ggml_cgraph * llm_build_refact(
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * inpSA = inpL;

-        // norm
-        {
-            cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
-            cb(cur, "rms_norm_0", il);
-
-            // cur = cur*attn_norm(broadcasted)
-            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
-            cb(cur, "attn_norm_0", il);
-        }
+        cur = llm_build_norm(ctx0, inpL,
+                model.layers[il].attn_norm,
+                NULL,
+                LLM_NORM_RMS, norm_rms_eps, cb, il);
+        cb(cur, "attn_norm", il);

         // self-attention
         {

@@ -4719,15 +4681,11 @@ static struct ggml_cgraph * llm_build_refact(

         // feed-forward network
         {
-            // norm
-            {
-                cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
-                cb(cur, "rms_norm_1", il);
-
-                // cur = cur*ffn_norm(broadcasted)
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
-                cb(cur, "ffn_norm", il);
-            }
+            cur = llm_build_norm(ctx0, inpFF,
+                    model.layers[il].ffn_norm,
+                    NULL,
+                    LLM_NORM_RMS, norm_rms_eps, cb, il);
+            cb(cur, "ffn_norm", il);

             struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                     model.layers[il].w3,

@@ -4761,15 +4719,11 @@ static struct ggml_cgraph * llm_build_refact(

     cur = inpL;

-    // norm
-    {
-        cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
-        cb(cur, "rms_norm_2", -1);
-
-        // cur = cur*norm(broadcasted)
-        cur = ggml_mul(ctx0, cur, model.output_norm);
-        cb(cur, "result_norm", -1);
-    }
+    cur = llm_build_norm(ctx0, cur,
+            model.output_norm,
+            NULL,
+            LLM_NORM_RMS, norm_rms_eps, cb, -1);
+    cb(cur, "result_norm", -1);

     // lm_head
     cur = ggml_mul_mat(ctx0, model.output, cur);

@@ -4851,30 +4805,18 @@ static struct ggml_cgraph * llm_build_bloom(
     struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
     cb(KQ_mask, "KQ_mask", -1);

-    // norm
-    {
-        inpL = ggml_norm(ctx0, embd, norm_eps);
-        cb(inpL, "inp_norm", -1);
-
-        inpL = ggml_mul(ctx0, inpL, model.tok_norm);
-        cb(inpL, "inp_norm_w", -1);
-
-        inpL = ggml_add (ctx0, inpL, model.tok_norm_b);
-        cb(inpL, "inp_norm_wb", -1);
-    }
+    inpL = llm_build_norm(ctx0, embd,
+            model.tok_norm,
+            model.tok_norm_b,
+            LLM_NORM, norm_eps, cb, -1);
+    cb(inpL, "inp_norm", -1);

     for (int il = 0; il < n_layer; ++il) {
-        {
-            // Norm
-            cur = ggml_norm(ctx0, inpL, norm_eps);
-            cb(cur, "attn_norm_0", il);
-
-            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
-            cb(cur, "attn_norm_0_w", il);
-
-            cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b);
-            cb(cur, "attn_norm_0_wb", il);
-        }
+        cur = llm_build_norm(ctx0, inpL,
+                model.layers[il].attn_norm,
+                model.layers[il].attn_norm_b,
+                LLM_NORM, norm_eps, cb, il);
+        cb(cur, "attn_norm", il);

         {
             // Self Attention

@@ -4984,17 +4926,11 @@ static struct ggml_cgraph * llm_build_bloom(

         // FF
         {
-            // Norm
-            {
-                cur = ggml_norm(ctx0, inpFF, norm_eps);
-                cb(cur, "ffn_norm_0", il);
-
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
-                cb(cur, "ffn_norm_0_w", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].ffn_norm_b);
-                cb(cur, "ffn_norm_0_wb", il);
-            }
+            cur = llm_build_norm(ctx0, inpFF,
+                    model.layers[il].ffn_norm,
+                    model.layers[il].ffn_norm_b,
+                    LLM_NORM, norm_eps, cb, il);
+            cb(cur, "ffn_norm", il);

             cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
             cb(cur, "result_w3", il);

|
|||
cb(inpL, "inpFF_+_result_w2", il);
|
||||
}
|
||||
|
||||
// Output Norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, inpL, norm_eps);
|
||||
cb(cur, "out_norm_0", -1);
|
||||
|
||||
cur = ggml_mul(ctx0, cur, model.output_norm);
|
||||
cb(cur, "out_norm_0_w", -1);
|
||||
|
||||
cur = ggml_add(ctx0, cur, model.output_norm_b);
|
||||
cb(cur, "result_norm", -1);
|
||||
}
|
||||
cur = llm_build_norm(ctx0, inpL,
|
||||
model.output_norm,
|
||||
model.output_norm_b,
|
||||
LLM_NORM, norm_eps, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
@@ -5109,18 +5039,15 @@ static struct ggml_cgraph * llm_build_mpt(
     for (int il = 0; il < n_layer; ++il) {
         struct ggml_tensor * attn_norm;

+        attn_norm = llm_build_norm(ctx0, inpL,
+                model.layers[il].attn_norm,
+                NULL,
+                LLM_NORM, norm_eps, cb, il);
+        cb(attn_norm, "attn_norm", il);
+
         // self-attention
-        // TODO: refactor into common function (shared with LLaMA)
         {
-            attn_norm = ggml_norm(ctx0, inpL, norm_eps);
-            cb(attn_norm, "attn_norm_0", il);
-
-            attn_norm = ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm);
-            cb(attn_norm, "attn_norm_0_w", il);
-
-            if (1) {
-                cur = attn_norm;
-            }
+            cur = attn_norm;

             // compute QKV

@@ -5230,14 +5157,11 @@ static struct ggml_cgraph * llm_build_mpt(

         // feed forward
         {
-            // Norm
-            {
-                cur = ggml_norm(ctx0, attn_out, norm_eps);
-                cb(cur, "ffn_norm_0", il);
-
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
-                cb(cur, "ffn_norm_0_w", il);
-            }
+            cur = llm_build_norm(ctx0, attn_out,
+                    model.layers[il].ffn_norm,
+                    NULL,
+                    LLM_NORM, norm_eps, cb, il);
+            cb(cur, "ffn_norm", il);

             cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
             cb(cur, "result_w3", il);

|
|||
|
||||
cur = inpL;
|
||||
|
||||
// norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, cur, norm_eps);
|
||||
cb(cur, "out_norm_0", -1);
|
||||
|
||||
cur = ggml_mul(ctx0, cur, model.output_norm);
|
||||
cb(cur, "result_norm", -1);
|
||||
}
|
||||
cur = llm_build_norm(ctx0, cur,
|
||||
model.output_norm,
|
||||
NULL,
|
||||
LLM_NORM, norm_eps, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
@@ -5378,15 +5299,12 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "inp_norm_w", OFFLOAD_FUNC_NR },
     { "inp_norm_wb", OFFLOAD_FUNC_NR },

-    { "rms_norm_0", OFFLOAD_FUNC },
-
-    { "attn_norm_0", OFFLOAD_FUNC },
-    { "attn_norm_0_w", OFFLOAD_FUNC },
-    { "attn_norm_0_wb", OFFLOAD_FUNC },
+    { "norm", OFFLOAD_FUNC },
+    { "norm_w", OFFLOAD_FUNC },
+    { "norm_wb", OFFLOAD_FUNC },

     { "attn_norm", OFFLOAD_FUNC },
     { "attn_norm_2", OFFLOAD_FUNC },
-    { "attn_norm_2_w", OFFLOAD_FUNC },
-    { "attn_norm_2_wb", OFFLOAD_FUNC },

     { "wqkv", OFFLOAD_FUNC_KQ },
     { "bqkv", OFFLOAD_FUNC_KQ },

@@ -5614,20 +5532,19 @@ static struct ggml_cgraph * llama_build_graph(

     static const std::unordered_map<llm_offload_func_e, std::string, std::hash<int>> k_offload_func_name = {
         { OFFLOAD_FUNC_NOP, "CPU" },
+        { OFFLOAD_FUNC_OUT, "CPU" },
 #ifdef GGML_USE_CUBLAS
         { OFFLOAD_FUNC, "GPU (CUDA)" },
         { OFFLOAD_FUNC_KQ, "GPU (CUDA) KQ" },
         { OFFLOAD_FUNC_V, "GPU (CUDA) V" },
         { OFFLOAD_FUNC_NR, "GPU (CUDA) NR" },
         { OFFLOAD_FUNC_EMB, "GPU (CUDA) EMB" },
-        { OFFLOAD_FUNC_OUT, "GPU (CUDA) OUT" },
 #else
         { OFFLOAD_FUNC, "CPU" },
         { OFFLOAD_FUNC_KQ, "CPU" },
         { OFFLOAD_FUNC_V, "CPU" },
         { OFFLOAD_FUNC_NR, "CPU" },
         { OFFLOAD_FUNC_EMB, "CPU" },
-        { OFFLOAD_FUNC_OUT, "CPU" },
 #endif // GGML_USE_CUBLAS
     };
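
The renamed map entries above matter because, in this refactor, the build callback uses the name passed to cb() to look up an offload function in k_offload_map (and a readable backend label in k_offload_func_name). The snippet below is only a rough standalone sketch of that name-keyed lookup pattern, with a made-up enum and table contents, not the actual llama.cpp logic.

#include <cstdio>
#include <string>
#include <unordered_map>

// Rough illustration only: name-keyed dispatch in the spirit of k_offload_map /
// k_offload_func_name. The enum, table contents, and behavior are hypothetical.
enum offload_func_e { OFFLOAD_FUNC_NOP, OFFLOAD_FUNC };

int main() {
    const std::unordered_map<std::string, offload_func_e> offload_map = {
        { "norm",      OFFLOAD_FUNC },
        { "norm_w",    OFFLOAD_FUNC },
        { "norm_wb",   OFFLOAD_FUNC },
        { "attn_norm", OFFLOAD_FUNC },
    };

    const std::unordered_map<int, std::string> func_name = {
        { OFFLOAD_FUNC_NOP, "CPU" },
        { OFFLOAD_FUNC,     "GPU" },
    };

    // Stand-in for the build callback: decide placement from the tensor name.
    auto cb = [&](const char * name, int il) {
        const auto it = offload_map.find(name);
        const offload_func_e f = it == offload_map.end() ? OFFLOAD_FUNC_NOP : it->second;
        std::printf("layer %2d: %-10s -> %s\n", il, name, func_name.at(f).c_str());
    };

    cb("norm",      0);
    cb("attn_norm", 0);
    cb("KQ_mask",  -1);
    return 0;
}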