feature : support Baichuan serial models (#3009)
This commit is contained in:
parent
35f73049af
commit
4c8643dd6e
4 changed files with 781 additions and 3 deletions
462
llama.cpp
462
llama.cpp
|
@ -155,6 +155,7 @@ static std::string format(const char * fmt, ...) {
|
|||
enum llm_arch {
|
||||
LLM_ARCH_LLAMA,
|
||||
LLM_ARCH_FALCON,
|
||||
LLM_ARCH_BAICHUAN,
|
||||
LLM_ARCH_GPT2,
|
||||
LLM_ARCH_GPTJ,
|
||||
LLM_ARCH_GPTNEOX,
|
||||
|
@ -169,6 +170,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_GPTJ, "gptj" },
|
||||
{ LLM_ARCH_GPTNEOX, "gptneox" },
|
||||
{ LLM_ARCH_MPT, "mpt" },
|
||||
{ LLM_ARCH_BAICHUAN,"baichuan" },
|
||||
};
|
||||
|
||||
enum llm_kv {
|
||||
|
@ -309,6 +311,25 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
|
|||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_BAICHUAN,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_FALCON,
|
||||
{
|
||||
|
@ -1683,6 +1704,15 @@ static void llm_load_hparams(
|
|||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BAICHUAN:
|
||||
{
|
||||
GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
|
||||
switch (hparams.n_layer) {
|
||||
case 32: model.type = e_model::MODEL_7B; break;
|
||||
case 40: model.type = e_model::MODEL_13B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: (void)0;
|
||||
};
|
||||
|
||||
|
@ -1923,7 +1953,6 @@ static void llm_load_tensors(
|
|||
const int64_t n_vocab = hparams.n_vocab;
|
||||
|
||||
const auto tn = LLM_TN(model.arch);
|
||||
|
||||
switch (model.arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
{
|
||||
|
@ -1966,6 +1995,72 @@ static void llm_load_tensors(
|
|||
|
||||
model.layers.resize(n_layer);
|
||||
|
||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||
const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
|
||||
const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
|
||||
|
||||
layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split);
|
||||
layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split);
|
||||
layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split);
|
||||
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
|
||||
|
||||
layer.w1 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
|
||||
layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
|
||||
layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
|
||||
|
||||
if (backend == GGML_BACKEND_GPU) {
|
||||
vram_weights +=
|
||||
ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
|
||||
ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) +
|
||||
ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_BAICHUAN:
|
||||
{
|
||||
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
|
||||
{
|
||||
ggml_backend backend_norm;
|
||||
ggml_backend backend_output;
|
||||
|
||||
if (n_gpu_layers > int(n_layer)) {
|
||||
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
|
||||
// on Windows however this is detrimental unless everything is on the GPU
|
||||
#ifndef _WIN32
|
||||
backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
|
||||
#else
|
||||
backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
|
||||
#endif // _WIN32
|
||||
|
||||
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
|
||||
} else {
|
||||
backend_norm = GGML_BACKEND_CPU;
|
||||
backend_output = GGML_BACKEND_CPU;
|
||||
}
|
||||
|
||||
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
|
||||
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
|
||||
|
||||
if (backend_norm == GGML_BACKEND_GPU) {
|
||||
vram_weights += ggml_nbytes(model.output_norm);
|
||||
}
|
||||
if (backend_output == GGML_BACKEND_GPU_SPLIT) {
|
||||
vram_weights += ggml_nbytes(model.output);
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t n_ff = hparams.n_ff;
|
||||
|
||||
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||
|
||||
model.layers.resize(n_layer);
|
||||
|
||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||
const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
|
||||
const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
|
||||
|
@ -2542,6 +2637,367 @@ static struct ggml_cgraph * llm_build_llama(
|
|||
return gf;
|
||||
}
|
||||
|
||||
|
||||
static struct ggml_cgraph * llm_build_baichaun(
|
||||
llama_context & lctx,
|
||||
const llama_token * tokens,
|
||||
const float * embd,
|
||||
int n_tokens,
|
||||
int n_past) {
|
||||
|
||||
GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
|
||||
|
||||
const int N = n_tokens;
|
||||
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const auto & kv_self = lctx.kv_self;
|
||||
|
||||
GGML_ASSERT(!!kv_self.ctx);
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
const int64_t n_ctx = hparams.n_ctx;
|
||||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_head_kv = hparams.n_head_kv;
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
const float freq_base = hparams.rope_freq_base;
|
||||
const float freq_scale = hparams.rope_freq_scale;
|
||||
const float norm_rms_eps = hparams.f_norm_rms_eps;
|
||||
|
||||
const int n_gpu_layers = model.n_gpu_layers;
|
||||
|
||||
auto & buf_compute = lctx.buf_compute;
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute.size,
|
||||
/*.mem_buffer =*/ buf_compute.data,
|
||||
/*.no_alloc =*/ false,
|
||||
};
|
||||
|
||||
params.no_alloc = true;
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
if (tokens) {
|
||||
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
|
||||
ggml_allocr_alloc(lctx.alloc, inp_tokens);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
|
||||
}
|
||||
ggml_set_name(inp_tokens, "inp_tokens");
|
||||
|
||||
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
|
||||
} else {
|
||||
#ifdef GGML_USE_MPI
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
#endif
|
||||
|
||||
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
|
||||
|
||||
ggml_allocr_alloc(lctx.alloc, inpL);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
|
||||
}
|
||||
}
|
||||
|
||||
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||
(void) i_gpu_start;
|
||||
|
||||
// offload functions set the tensor output backend to GPU
|
||||
// tensors are GPU-accelerated if any input or the output has been offloaded
|
||||
//
|
||||
// with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
|
||||
// in that case ggml_cuda_assign_buffers has no effect
|
||||
offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
|
||||
offload_func_t offload_func_kq = llama_nop;
|
||||
offload_func_t offload_func_v = llama_nop;
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (n_gpu_layers > n_layer) {
|
||||
offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
|
||||
}
|
||||
if (n_gpu_layers > n_layer + 1) {
|
||||
offload_func_v = ggml_cuda_assign_buffers_no_alloc;
|
||||
}
|
||||
if (n_gpu_layers > n_layer + 2) {
|
||||
offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
|
||||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||
ggml_allocr_alloc(lctx.alloc, KQ_scale);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
|
||||
}
|
||||
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_format_name(inpL, "layer_inp_%d", il);
|
||||
|
||||
offload_func_t offload_func = llama_nop;
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (il >= i_gpu_start) {
|
||||
offload_func = ggml_cuda_assign_buffers_no_alloc;
|
||||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
{
|
||||
cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "rms_norm_0");
|
||||
|
||||
// cur = cur*attn_norm(broadcasted)
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "attention_norm_0");
|
||||
}
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
offload_func_kq(tmpk);
|
||||
ggml_set_name(tmpk, "tmpk");
|
||||
|
||||
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
offload_func_kq(tmpq);
|
||||
ggml_set_name(tmpq, "tmpq");
|
||||
|
||||
struct ggml_tensor * Kcur;
|
||||
struct ggml_tensor * Qcur;
|
||||
switch (model.type) {
|
||||
case MODEL_7B:
|
||||
Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
|
||||
break;
|
||||
case MODEL_13B:
|
||||
Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N);
|
||||
Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
offload_func_kq(Kcur);
|
||||
ggml_set_name(Kcur, "Kcur");
|
||||
|
||||
offload_func_kq(Qcur);
|
||||
ggml_set_name(Qcur, "Qcur");
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
// compute the transposed [N, n_embd] V matrix
|
||||
|
||||
struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
offload_func_v(tmpv);
|
||||
ggml_set_name(tmpv, "tmpv");
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
|
||||
offload_func_v(Vcur);
|
||||
ggml_set_name(Vcur, "Vcur");
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
|
||||
offload_func_kq(k);
|
||||
ggml_set_name(k, "k");
|
||||
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
|
||||
offload_func_v(v);
|
||||
ggml_set_name(v, "v");
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
offload_func_kq(Q);
|
||||
ggml_set_name(Q, "Q");
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_view_3d(ctx0, kv_self.k,
|
||||
n_embd_head, n_past + N, n_head_kv,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa,
|
||||
ggml_element_size(kv_self.k)*n_embd_head,
|
||||
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
|
||||
offload_func_kq(K);
|
||||
ggml_set_name(K, "K");
|
||||
|
||||
// K * Q
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
offload_func_kq(KQ);
|
||||
ggml_set_name(KQ, "KQ");
|
||||
|
||||
// KQ_scaled = KQ / sqrt(n_embd_head)
|
||||
// KQ_scaled shape [n_past + N, N, n_head, 1]
|
||||
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
|
||||
offload_func_kq(KQ_scaled);
|
||||
ggml_set_name(KQ_scaled, "KQ_scaled");
|
||||
|
||||
struct ggml_tensor * KQ_masked;
|
||||
struct ggml_tensor * KQ_scaled_alibi;
|
||||
|
||||
switch (model.type) {
|
||||
case MODEL_7B:
|
||||
KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
|
||||
break;
|
||||
case MODEL_13B:
|
||||
KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8);
|
||||
ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
|
||||
KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
// KQ_masked = mask_past(KQ_scaled)
|
||||
// struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
|
||||
// struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
|
||||
// offload_func_kq(KQ_masked);
|
||||
// ggml_set_name(KQ_masked, "KQ_masked");
|
||||
|
||||
// KQ = soft_max(KQ_masked)
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
|
||||
offload_func_v(KQ_soft_max);
|
||||
ggml_set_name(KQ_soft_max, "KQ_soft_max");
|
||||
|
||||
// split cached V into n_head heads
|
||||
struct ggml_tensor * V =
|
||||
ggml_view_3d(ctx0, kv_self.v,
|
||||
n_past + N, n_embd_head, n_head_kv,
|
||||
ggml_element_size(kv_self.v)*n_ctx,
|
||||
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
|
||||
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
|
||||
offload_func_v(V);
|
||||
ggml_set_name(V, "V");
|
||||
|
||||
#if 1
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||
offload_func_v(KQV);
|
||||
ggml_set_name(KQV, "KQV");
|
||||
#else
|
||||
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
|
||||
// on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
|
||||
// is there a better way?
|
||||
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
|
||||
#endif
|
||||
|
||||
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
offload_func_v(KQV_merged);
|
||||
ggml_set_name(KQV_merged, "KQV_merged");
|
||||
|
||||
// cur = KQV_merged.contiguous().view(n_embd, N)
|
||||
cur = ggml_cpy(ctx0,
|
||||
KQV_merged,
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||
offload_func_v(cur);
|
||||
ggml_set_name(cur, "KQV_merged_contiguous");
|
||||
|
||||
// projection (no bias)
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].wo,
|
||||
cur);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "result_wo");
|
||||
}
|
||||
|
||||
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
|
||||
offload_func(inpFF);
|
||||
ggml_set_name(inpFF, "inpFF");
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
// norm
|
||||
{
|
||||
cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "rms_norm_1");
|
||||
|
||||
// cur = cur*ffn_norm(broadcasted)
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "ffn_norm");
|
||||
}
|
||||
|
||||
struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
|
||||
model.layers[il].w3,
|
||||
cur);
|
||||
offload_func(tmp);
|
||||
ggml_set_name(tmp, "result_w3");
|
||||
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].w1,
|
||||
cur);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "result_w1");
|
||||
|
||||
// SILU activation
|
||||
cur = ggml_silu(ctx0, cur);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "silu");
|
||||
|
||||
cur = ggml_mul(ctx0, cur, tmp);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "silu_x_result_w3");
|
||||
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].w2,
|
||||
cur);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "result_w2");
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpFF);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "inpFF_+_result_w2");
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
// norm
|
||||
{
|
||||
cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
|
||||
offload_func_nr(cur);
|
||||
ggml_set_name(cur, "rms_norm_2");
|
||||
|
||||
// cur = cur*norm(broadcasted)
|
||||
cur = ggml_mul(ctx0, cur, model.output_norm);
|
||||
// offload_func_nr(cur); // TODO CPU + GPU mirrored backend
|
||||
ggml_set_name(cur, "result_norm");
|
||||
}
|
||||
|
||||
// lm_head
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
ggml_set_name(cur, "result_output");
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
static struct ggml_cgraph * llm_build_falcon(
|
||||
llama_context & lctx,
|
||||
const llama_token * tokens,
|
||||
|
@ -2864,6 +3320,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
{
|
||||
result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past);
|
||||
} break;
|
||||
case LLM_ARCH_BAICHUAN:
|
||||
{
|
||||
result = llm_build_baichaun(lctx, tokens, embd, n_tokens, n_past);
|
||||
} break;
|
||||
case LLM_ARCH_FALCON:
|
||||
{
|
||||
result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue