Refactor graph building to reduce duplication

This commit is contained in:
KerfuffleV2 2023-10-11 20:04:01 -06:00
parent b8fe4b5cc9
commit aa7fbef78c

571
llama.cpp
View file

@ -2966,191 +2966,271 @@ static bool llama_model_load(
return true; return true;
} }
static struct ggml_cgraph * llm_build_llama( struct llm_build_ctx {
llama_context & lctx, struct ggml_context * ctx0 = nullptr;
const llama_batch & batch) { ggml_cgraph * gf = nullptr;
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self; ggml_allocr * alloc;
bool alloc_measure;
llama_buffer & buf_compute;
GGML_ASSERT(!!kv_self.ctx); const llama_batch & batch;
const int64_t n_embd = hparams.n_embd; const llama_model & model;
const int64_t n_layer = hparams.n_layer; const std::vector<llama_layer> & layers;
const int64_t n_ctx = cparams.n_ctx; const llama_hparams & hparams;
const int64_t n_head = hparams.n_head; const llama_cparams & cparams;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
const int64_t n_embd_gqa = hparams.n_embd_gqa();
GGML_ASSERT(n_embd_head == hparams.n_rot); const llama_kv_cache & kv_self;
const float freq_base = cparams.rope_freq_base; const int64_t n_embd;
const float freq_scale = cparams.rope_freq_scale; const int64_t n_layer;
const float norm_rms_eps = hparams.f_norm_rms_eps; const int64_t n_ctx;
const int64_t n_head;
const int64_t n_head_kv;
const int64_t n_embd_head;
const int64_t n_embd_gqa;
const int n_gpu_layers = model.n_gpu_layers; const float freq_base;
const float freq_scale;
const float norm_rms_eps;
const int32_t n_tokens = batch.n_tokens; const int n_gpu_layers;
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n;
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; const int32_t n_tokens;
const int32_t n_kv;
const int32_t kv_head;
//printf("n_kv = %d\n", n_kv); const bool do_rope_shift;
auto & buf_compute = lctx.buf_compute; offload_func_t offload_func = llama_nop;
offload_func_t offload_func_nr = llama_nop;
struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data,
/*.no_alloc =*/ true,
};
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_allocr_alloc(lctx.alloc, inp_tokens);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
}
ggml_set_name(inp_tokens, "inp_tokens");
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
ggml_allocr_alloc(lctx.alloc, inpL);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
}
}
const int i_gpu_start = n_layer - n_gpu_layers;
(void) i_gpu_start;
// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
offload_func_t offload_func_kq = llama_nop; offload_func_t offload_func_kq = llama_nop;
offload_func_t offload_func_v = llama_nop; offload_func_t offload_func_v = llama_nop;
#ifdef GGML_USE_CUBLAS llm_build_ctx(llama_context & lctx, const llama_batch & batch)
if (n_gpu_layers > n_layer) { : alloc (lctx.alloc)
offload_func_nr = ggml_cuda_assign_buffers_no_alloc; , alloc_measure (ggml_allocr_is_measure(alloc))
} , buf_compute (lctx.buf_compute)
if (n_gpu_layers > n_layer + 1) { , batch (batch)
offload_func_v = ggml_cuda_assign_buffers_no_alloc; , model (lctx.model)
} , layers (model.layers)
if (n_gpu_layers > n_layer + 2) { , hparams (model.hparams)
offload_func_kq = ggml_cuda_assign_buffers_no_alloc; , cparams (lctx.cparams)
} , kv_self (lctx.kv_self)
#endif // GGML_USE_CUBLAS
// KQ_scale , n_embd (hparams.n_embd)
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); , n_layer (hparams.n_layer)
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); , n_ctx (cparams.n_ctx)
ggml_allocr_alloc(lctx.alloc, KQ_scale); , n_head (hparams.n_head)
if (!ggml_allocr_is_measure(lctx.alloc)) { , n_head_kv (hparams.n_head_kv)
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); , n_embd_head (hparams.n_embd_head())
, n_embd_gqa (hparams.n_embd_gqa())
, freq_base (cparams.rope_freq_base)
, freq_scale (cparams.rope_freq_scale)
, norm_rms_eps (hparams.f_norm_eps)
, n_gpu_layers (model.n_gpu_layers)
, n_tokens (batch.n_tokens)
, n_kv (alloc_measure ? n_ctx : kv_self.n)
, kv_head (alloc_measure ? n_ctx - n_tokens : kv_self.head)
, do_rope_shift (alloc_measure || kv_self.has_shift)
{
GGML_ASSERT(!!kv_self.ctx);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data,
/*.no_alloc =*/ true,
};
ctx0 = ggml_init(params);
gf = ggml_new_graph(ctx0);
} }
// KQ_mask (mask for 1 head, it will be broadcasted to all heads) ~llm_build_ctx() {
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); ggml_free(ctx0);
offload_func_kq(KQ_mask); ctx0 = nullptr;
ggml_set_name(KQ_mask, "KQ_mask"); gf = nullptr;
ggml_allocr_alloc(lctx.alloc, KQ_mask); }
if (!ggml_allocr_is_measure(lctx.alloc)) {
float * data = (float *) KQ_mask->data;
memset(data, 0, ggml_nbytes(KQ_mask));
for (int h = 0; h < 1; ++h) { struct ggml_tensor * build_pre_repeating();
for (int j = 0; j < n_tokens; ++j) { ggml_cgraph * build_post_repeating();
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { struct ggml_tensor * build_attn_block(
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; const int32_t il,
ggml_tensor * input);
struct ggml_tensor * build_ffn_block(
const int32_t il,
ggml_tensor * input);
};
struct llm_build_llama_ctx : llm_build_ctx {
struct ggml_tensor * KQ_pos = nullptr;
struct ggml_tensor * KQ_scale = nullptr;
struct ggml_tensor * KQ_mask = nullptr;
llm_build_llama_ctx(llama_context & lctx, const llama_batch & batch)
: llm_build_ctx(lctx, batch)
{}
struct ggml_tensor *build_pre_repeating() {
struct ggml_tensor * inpL;
if (batch.token) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_allocr_alloc(alloc, inp_tokens);
if (!alloc_measure) {
memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
}
ggml_set_name(inp_tokens, "inp_tokens");
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
ggml_allocr_alloc(alloc, inpL);
if (!alloc_measure) {
memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
}
}
const int i_gpu_start = n_layer - n_gpu_layers;
(void) i_gpu_start;
// offload functions set the tensor output backend to GPU
// tensors are GPU-accelerated if any input or the output has been offloaded
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > n_layer) {
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 1) {
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
}
if (n_gpu_layers > n_layer + 2) {
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS
// KQ_scale
KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
ggml_allocr_alloc(alloc, KQ_scale);
if (!alloc_measure) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head)));
}
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
offload_func_kq(KQ_mask);
ggml_set_name(KQ_mask, "KQ_mask");
ggml_allocr_alloc(alloc, KQ_mask);
if (!alloc_measure) {
float * data = (float *) KQ_mask->data;
memset(data, 0, ggml_nbytes(KQ_mask));
for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j];
for (int i = 0; i < n_kv; ++i) {
if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
}
} }
} }
} }
} }
}
// KQ_pos - contains the positions // KQ_pos - contains the positions
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
offload_func_kq(KQ_pos); offload_func_kq(KQ_pos);
ggml_set_name(KQ_pos, "KQ_pos"); ggml_set_name(KQ_pos, "KQ_pos");
ggml_allocr_alloc(lctx.alloc, KQ_pos); ggml_allocr_alloc(alloc, KQ_pos);
if (!ggml_allocr_is_measure(lctx.alloc)) { if (!alloc_measure) {
int * data = (int *) KQ_pos->data; int * data = (int *) KQ_pos->data;
for (int i = 0; i < n_tokens; ++i) { for (int i = 0; i < n_tokens; ++i) {
data[i] = batch.pos[i]; data[i] = batch.pos[i];
}
}
// shift the entire K-cache if needed
if (do_rope_shift) {
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
offload_func_kq(K_shift);
ggml_set_name(K_shift, "K_shift");
ggml_allocr_alloc(lctx.alloc, K_shift);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) K_shift->data;
for (int i = 0; i < n_ctx; ++i) {
data[i] = kv_self.cells[i].delta;
} }
} }
for (int il = 0; il < n_layer; ++il) { // shift the entire K-cache if needed
struct ggml_tensor * tmp = if (do_rope_shift) {
ggml_rope_custom_inplace(ctx0, struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
ggml_view_3d(ctx0, kv_self.k, offload_func_kq(K_shift);
n_embd_head, n_head_kv, n_ctx, ggml_set_name(K_shift, "K_shift");
ggml_element_size(kv_self.k)*n_embd_head, ggml_allocr_alloc(alloc, K_shift);
ggml_element_size(kv_self.k)*n_embd_gqa, if (!alloc_measure) {
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il), int * data = (int *) K_shift->data;
K_shift, n_embd_head, 0, 0, freq_base, freq_scale); for (int i = 0; i < n_ctx; ++i) {
offload_func_kq(tmp); data[i] = kv_self.cells[i].delta;
ggml_build_forward_expand(gf, tmp); }
}
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp =
ggml_rope_custom_inplace(ctx0,
ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_head_kv, n_ctx,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(tmp);
ggml_build_forward_expand(gf, tmp);
}
} }
return inpL;
} }
for (int il = 0; il < n_layer; ++il) { ggml_cgraph * build_post_repeating(ggml_tensor * cur) {
ggml_format_name(inpL, "layer_inp_%d", il); // norm
{
cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_2");
offload_func_t offload_func = llama_nop; // cur = cur*norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.output_norm);
#ifdef GGML_USE_CUBLAS // offload_func_nr(cur); // TODO CPU + GPU mirrored backend
if (il >= i_gpu_start) { ggml_set_name(cur, "result_norm");
offload_func = ggml_cuda_assign_buffers_no_alloc;
} }
#endif // GGML_USE_CUBLAS
struct ggml_tensor * inpSA = inpL; // lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
ggml_set_name(cur, "result_output");
ggml_build_forward_expand(gf, cur);
return gf;
}
struct ggml_tensor *build_attn_block(
const int32_t il,
ggml_tensor * input) {
const llama_layer & layer = layers[il];
const size_t v_elsize = ggml_element_size(kv_self.v);
const size_t k_elsize = ggml_element_size(kv_self.k);
ggml_tensor * cur;
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
offload_func(cur);
ggml_set_name(cur, "rms_norm_0");
// cur = cur*attn_norm(broadcasted) // cur = cur*attn_norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm); cur = ggml_mul(ctx0, input, layer.attn_norm);
offload_func(cur); offload_func(cur);
ggml_set_name(cur, "attention_norm_0"); ggml_set_name(cur, "attention_norm_0");
} }
@ -3158,19 +3238,23 @@ static struct ggml_cgraph * llm_build_llama(
// self-attention // self-attention
{ {
// compute Q and K and RoPE them // compute Q and K and RoPE them
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, layer.wk, cur);
offload_func_kq(tmpk); offload_func_kq(tmpk);
ggml_set_name(tmpk, "tmpk"); ggml_set_name(tmpk, "tmpk");
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, layer.wq, cur);
offload_func_kq(tmpq); offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq"); ggml_set_name(tmpq, "tmpq");
struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); struct ggml_tensor * Kcur = ggml_rope_custom(ctx0,
ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens),
KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Kcur); offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur"); ggml_set_name(Kcur, "Kcur");
struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); struct ggml_tensor * Qcur = ggml_rope_custom(ctx0,
ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens),
KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Qcur); offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur"); ggml_set_name(Qcur, "Qcur");
@ -3178,7 +3262,7 @@ static struct ggml_cgraph * llm_build_llama(
{ {
// compute the transposed [n_tokens, n_embd] V matrix // compute the transposed [n_tokens, n_embd] V matrix
struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, layer.wv, cur);
offload_func_v(tmpv); offload_func_v(tmpv);
ggml_set_name(tmpv, "tmpv"); ggml_set_name(tmpv, "tmpv");
@ -3186,13 +3270,13 @@ static struct ggml_cgraph * llm_build_llama(
offload_func_v(Vcur); offload_func_v(Vcur);
ggml_set_name(Vcur, "Vcur"); ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (k_elsize*n_embd_gqa)*(il*n_ctx + kv_head));
offload_func_kq(k); offload_func_kq(k);
ggml_set_name(k, "k"); ggml_set_name(k, "k");
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
( n_ctx)*ggml_element_size(kv_self.v), ( n_ctx)*v_elsize,
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); (il*n_ctx)*v_elsize*n_embd_gqa + kv_head*v_elsize);
offload_func_v(v); offload_func_v(v);
ggml_set_name(v, "v"); ggml_set_name(v, "v");
@ -3208,9 +3292,9 @@ static struct ggml_cgraph * llm_build_llama(
struct ggml_tensor * K = struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_self.k, ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_kv, n_head_kv, n_embd_head, n_kv, n_head_kv,
ggml_element_size(kv_self.k)*n_embd_gqa, k_elsize*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_head, k_elsize*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); k_elsize*n_embd_gqa*n_ctx*il);
offload_func_kq(K); offload_func_kq(K);
ggml_set_name(K, "K"); ggml_set_name(K, "K");
@ -3239,23 +3323,15 @@ static struct ggml_cgraph * llm_build_llama(
struct ggml_tensor * V = struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v, ggml_view_3d(ctx0, kv_self.v,
n_kv, n_embd_head, n_head_kv, n_kv, n_embd_head, n_head_kv,
ggml_element_size(kv_self.v)*n_ctx, v_elsize*n_ctx,
ggml_element_size(kv_self.v)*n_ctx*n_embd_head, v_elsize*n_ctx*n_embd_head,
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); v_elsize*n_ctx*n_embd_gqa*il);
offload_func_v(V); offload_func_v(V);
ggml_set_name(V, "V"); ggml_set_name(V, "V");
#if 1
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
offload_func_v(KQV); offload_func_v(KQV);
ggml_set_name(KQV, "KQV"); ggml_set_name(KQV, "KQV");
#else
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
// on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
// is there a better way?
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head));
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
#endif
// KQV_merged = KQV.permute(0, 2, 1, 3) // KQV_merged = KQV.permute(0, 2, 1, 3)
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
@ -3269,59 +3345,104 @@ static struct ggml_cgraph * llm_build_llama(
// projection (no bias) // projection (no bias)
cur = ggml_mul_mat(ctx0, cur = ggml_mul_mat(ctx0,
model.layers[il].wo, layer.wo,
cur); cur);
offload_func(cur); offload_func(cur);
ggml_set_name(cur, "result_wo"); ggml_set_name(cur, "result_wo");
} }
return cur;
}
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA); struct ggml_tensor * build_ffn_block(
const int32_t il,
ggml_tensor * input) {
const llama_layer & layer = layers[il];
ggml_tensor * cur;
// norm
{
// cur = cur*ffn_norm(broadcasted)
cur = ggml_mul(ctx0, input, layer.ffn_norm);
offload_func(cur);
ggml_set_name(cur, "ffn_norm");
}
struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
layer.w3,
cur);
offload_func(tmp);
ggml_set_name(tmp, "result_w3");
cur = ggml_mul_mat(ctx0,
layer.w1,
cur);
offload_func(cur);
ggml_set_name(cur, "result_w1");
// SILU activation
cur = ggml_silu(ctx0, cur);
offload_func(cur);
ggml_set_name(cur, "silu");
cur = ggml_mul(ctx0, cur, tmp);
offload_func(cur);
ggml_set_name(cur, "silu_x_result_w3");
cur = ggml_mul_mat(ctx0,
layer.w2,
cur);
offload_func(cur);
ggml_set_name(cur, "result_w2");
return cur;
}
};
static struct ggml_cgraph * llm_build_llama(
llama_context & lctx,
const llama_batch & batch) {
llm_build_llama_ctx bctx(lctx, batch);
struct ggml_tensor * cur = nullptr;
struct ggml_tensor * inpL = bctx.build_pre_repeating();
const int i_gpu_start = bctx.n_layer - bctx.n_gpu_layers;
(void) i_gpu_start;
for (int il = 0; il < bctx.n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il);
offload_func_t offload_func = llama_nop;
#ifdef GGML_USE_CUBLAS
if (il >= i_gpu_start) {
bctx.offload_func = ggml_cuda_assign_buffers_no_alloc;
}
#endif // GGML_USE_CUBLAS
struct ggml_tensor * inpSA = inpL;
// norm
cur = ggml_rms_norm(bctx.ctx0, inpL, bctx.norm_rms_eps);
offload_func(cur);
ggml_set_name(cur, "rms_norm_0");
bctx.offload_func = offload_func;
cur = bctx.build_attn_block(il, cur);
struct ggml_tensor * inpFF = ggml_add(bctx.ctx0, cur, inpSA);
offload_func(inpFF); offload_func(inpFF);
ggml_set_name(inpFF, "inpFF"); ggml_set_name(inpFF, "inpFF");
// feed-forward network // norm
{ cur = ggml_rms_norm(bctx.ctx0, inpFF, bctx.norm_rms_eps);
// norm offload_func(cur);
{ ggml_set_name(cur, "rms_norm_1");
cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
offload_func(cur);
ggml_set_name(cur, "rms_norm_1");
// cur = cur*ffn_norm(broadcasted) cur = bctx.build_ffn_block(il, cur);
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
offload_func(cur);
ggml_set_name(cur, "ffn_norm");
}
struct ggml_tensor * tmp = ggml_mul_mat(ctx0, cur = ggml_add(bctx.ctx0, cur, inpFF);
model.layers[il].w3,
cur);
offload_func(tmp);
ggml_set_name(tmp, "result_w3");
cur = ggml_mul_mat(ctx0,
model.layers[il].w1,
cur);
offload_func(cur);
ggml_set_name(cur, "result_w1");
// SILU activation
cur = ggml_silu(ctx0, cur);
offload_func(cur);
ggml_set_name(cur, "silu");
cur = ggml_mul(ctx0, cur, tmp);
offload_func(cur);
ggml_set_name(cur, "silu_x_result_w3");
cur = ggml_mul_mat(ctx0,
model.layers[il].w2,
cur);
offload_func(cur);
ggml_set_name(cur, "result_w2");
}
cur = ggml_add(ctx0, cur, inpFF);
offload_func(cur); offload_func(cur);
ggml_set_name(cur, "inpFF_+_result_w2"); ggml_set_name(cur, "inpFF_+_result_w2");
@ -3329,27 +3450,7 @@ static struct ggml_cgraph * llm_build_llama(
inpL = cur; inpL = cur;
} }
cur = inpL; ggml_cgraph * gf = bctx.build_post_repeating(inpL);
// norm
{
cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_2");
// cur = cur*norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.output_norm);
// offload_func_nr(cur); // TODO CPU + GPU mirrored backend
ggml_set_name(cur, "result_norm");
}
// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
ggml_set_name(cur, "result_output");
ggml_build_forward_expand(gf, cur);
ggml_free(ctx0);
return gf; return gf;
} }