test-backend-ops : add Falcon test
This commit is contained in:
parent
f5ac635473
commit
1849b85473
1 changed files with 265 additions and 88 deletions
|
@ -1471,53 +1471,64 @@ struct test_moe : public test_case {
|
|||
};
|
||||
|
||||
|
||||
// llama
|
||||
struct test_llama : public test_case {
|
||||
const int n_tokens;
|
||||
static constexpr float f_norm_rms_eps = 1e-5;
|
||||
static constexpr int64_t n_embd_k_gqa = 3200;
|
||||
static constexpr int64_t n_embd_v_gqa = 3200;
|
||||
static constexpr int64_t n_ctx = 512;
|
||||
static constexpr int64_t n_layer = 1;
|
||||
static constexpr int64_t n_head = 32;
|
||||
static constexpr int64_t n_head_kv = 32;
|
||||
static constexpr int64_t n_embd_head = 100;
|
||||
static constexpr int64_t n_embd = 3200;
|
||||
static constexpr int64_t n_orig_ctx = n_ctx;
|
||||
static constexpr int64_t n_ff = 8640;
|
||||
static constexpr int64_t n_kv = 32;
|
||||
static constexpr int64_t kv_head = 1;
|
||||
static constexpr float freq_base = 10000.0f;
|
||||
static constexpr float freq_scale = 1.0f;
|
||||
static constexpr float ext_factor = 0.0f;
|
||||
static constexpr float attn_factor = 1.0f;
|
||||
static constexpr float beta_fast = 32.0f;
|
||||
static constexpr float beta_slow = 1.0f;
|
||||
enum llm_norm_type {
|
||||
LLM_NORM,
|
||||
LLM_NORM_RMS,
|
||||
};
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
return "LLAMA";
|
||||
struct llama_hparams {
|
||||
uint32_t n_vocab;
|
||||
uint32_t n_embd;
|
||||
uint32_t n_head;
|
||||
uint32_t n_head_kv;
|
||||
static constexpr uint32_t n_layer = 1;
|
||||
uint32_t n_rot;
|
||||
uint32_t n_embd_head; // dimension of values (d_v)
|
||||
uint32_t n_ff;
|
||||
|
||||
GGML_UNUSED(t);
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR1(n_tokens);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 2e-3;
|
||||
}
|
||||
|
||||
test_llama(int n_tokens = 1)
|
||||
: n_tokens(n_tokens) {
|
||||
float f_norm_eps;
|
||||
float f_norm_rms_eps;
|
||||
|
||||
// cparams
|
||||
static constexpr uint32_t n_ctx = 512; // user-specified context size
|
||||
static constexpr uint32_t n_orig_ctx = n_ctx;
|
||||
|
||||
// batch
|
||||
int32_t n_tokens;
|
||||
|
||||
// llm_build_context
|
||||
static constexpr int32_t n_kv = 32; // size of KV cache to consider (n_kv <= n_ctx
|
||||
static constexpr int32_t kv_head = 1; // index of where we store new KV data in the cache
|
||||
|
||||
uint32_t n_embd_gqa() const { // dimension of key embeddings across all k-v heads
|
||||
return n_embd_head * n_head_kv;
|
||||
}
|
||||
};
|
||||
|
||||
// LLM base class
|
||||
struct test_llm : public test_case {
|
||||
llama_hparams hp;
|
||||
|
||||
protected:
|
||||
test_llm(llama_hparams hp)
|
||||
: hp(std::move(hp)) {
|
||||
}
|
||||
|
||||
public:
|
||||
struct ggml_tensor * llm_build_norm(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * cur,
|
||||
struct ggml_tensor * mw) {
|
||||
cur = ggml_rms_norm(ctx, cur, f_norm_rms_eps);
|
||||
struct ggml_tensor * mw,
|
||||
struct ggml_tensor * mb,
|
||||
llm_norm_type type) {
|
||||
switch (type) {
|
||||
case LLM_NORM: cur = ggml_norm (ctx, cur, hp.f_norm_eps); break;
|
||||
case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hp.f_norm_rms_eps); break;
|
||||
}
|
||||
cur = ggml_mul(ctx, cur, mw);
|
||||
if (mb) {
|
||||
cur = ggml_add(ctx, cur, mb);
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
|
@ -1528,14 +1539,14 @@ struct test_llama : public test_case {
|
|||
struct ggml_tensor * k_cur,
|
||||
struct ggml_tensor * v_cur) {
|
||||
// compute the transposed [n_tokens, n_embd] V matrix
|
||||
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
|
||||
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, hp.n_embd_gqa(), hp.n_tokens));
|
||||
|
||||
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, k_l, n_tokens*n_embd_k_gqa,
|
||||
(ggml_row_size(k_l->type, n_embd_k_gqa))*kv_head);
|
||||
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, k_l, hp.n_tokens*hp.n_embd_gqa(),
|
||||
(ggml_row_size(k_l->type, hp.n_embd_gqa()))*hp.kv_head);
|
||||
|
||||
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, v_l, n_tokens, n_embd_v_gqa,
|
||||
( n_ctx)*ggml_element_size(v_l),
|
||||
(kv_head)*ggml_element_size(v_l));
|
||||
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, v_l, hp.n_tokens, hp.n_embd_gqa(),
|
||||
( hp.n_ctx)*ggml_element_size(v_l),
|
||||
(hp.kv_head)*ggml_element_size(v_l));
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_cpy(ctx, k_cur, k_cache_view);
|
||||
|
@ -1554,9 +1565,9 @@ struct test_llama : public test_case {
|
|||
|
||||
struct ggml_tensor * k =
|
||||
ggml_view_3d(ctx, k_l,
|
||||
n_embd_head, n_kv, n_head_kv,
|
||||
ggml_row_size(k_l->type, n_embd_k_gqa),
|
||||
ggml_row_size(k_l->type, n_embd_head),
|
||||
hp.n_embd_head, hp.n_kv, hp.n_head_kv,
|
||||
ggml_row_size(k_l->type, hp.n_embd_gqa()),
|
||||
ggml_row_size(k_l->type, hp.n_embd_head),
|
||||
0);
|
||||
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||
|
@ -1566,52 +1577,105 @@ struct test_llama : public test_case {
|
|||
// split cached v into n_head heads
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx, v_l,
|
||||
n_kv, n_embd_head, n_head_kv,
|
||||
ggml_element_size(v_l)*n_ctx,
|
||||
ggml_element_size(v_l)*n_ctx*n_embd_head,
|
||||
hp.n_kv, hp.n_embd_head, hp.n_head_kv,
|
||||
ggml_element_size(v_l)*hp.n_ctx,
|
||||
ggml_element_size(v_l)*hp.n_ctx*hp.n_embd_head,
|
||||
0);
|
||||
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
|
||||
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head*n_head, n_tokens);
|
||||
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, hp.n_embd_head*hp.n_head, hp.n_tokens);
|
||||
|
||||
struct ggml_tensor * wo = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 3200);
|
||||
struct ggml_tensor * wo = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
|
||||
cur = ggml_mul_mat(ctx, wo, cur);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
const int64_t n_rot = n_embd_head;
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
// pos
|
||||
std::vector<int> data(hp.n_tokens);
|
||||
for (int i = 0; i < hp.n_tokens; i++) {
|
||||
data[i] = rand() % hp.n_ctx;
|
||||
}
|
||||
ggml_backend_tensor_set(t, data.data(), 0, hp.n_tokens * sizeof(int));
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Llama
|
||||
struct test_llama : public test_llm {
|
||||
static constexpr float freq_base = 10000.0f;
|
||||
static constexpr float freq_scale = 1.0f;
|
||||
static constexpr float ext_factor = 0.0f;
|
||||
static constexpr float attn_factor = 1.0f;
|
||||
static constexpr float beta_fast = 32.0f;
|
||||
static constexpr float beta_slow = 1.0f;
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "LLAMA";
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
auto n_tokens = hp.n_tokens;
|
||||
return VARS_TO_STR1(n_tokens);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 2e-3;
|
||||
}
|
||||
|
||||
test_llama(int n_tokens = 1)
|
||||
: test_llm({
|
||||
/*n_vocab =*/ 32000,
|
||||
/*n_embd =*/ 3200,
|
||||
/*n_head =*/ 32,
|
||||
/*n_head_kv =*/ 32,
|
||||
/*n_rot =*/ 100,
|
||||
/*n_embd_head =*/ 100,
|
||||
/*n_ff =*/ 8640,
|
||||
/*f_norm_eps =*/ 0.f,
|
||||
/*f_norm_rms_eps =*/ 1e-5f,
|
||||
/*n_tokens =*/ n_tokens,
|
||||
}) {
|
||||
}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
|
||||
inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
|
||||
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_kv, n_tokens, 1);
|
||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
|
||||
|
||||
ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
||||
ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
for (uint32_t il = 0; il < hp.n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
ggml_tensor * attn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3200);
|
||||
cur = llm_build_norm(ctx, inpL, attn_norm);
|
||||
ggml_tensor * attn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
cur = llm_build_norm(ctx, inpL, attn_norm, nullptr, LLM_NORM_RMS);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
ggml_tensor * wq = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 3200);
|
||||
ggml_tensor * wk = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 3200);
|
||||
ggml_tensor * wv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 3200);
|
||||
ggml_tensor * wq = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
|
||||
ggml_tensor * wk = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
|
||||
ggml_tensor * wv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx, wq, cur);
|
||||
|
@ -1619,31 +1683,31 @@ struct test_llama : public test_case {
|
|||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);
|
||||
|
||||
Qcur = ggml_rope_custom(
|
||||
ctx, ggml_reshape_3d(ctx, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
|
||||
n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos,
|
||||
hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_custom(
|
||||
ctx, ggml_reshape_3d(ctx, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
|
||||
n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos,
|
||||
hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
|
||||
|
||||
cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(n_embd_head)));
|
||||
cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx, cur, inpSA);
|
||||
|
||||
// feed-forward network
|
||||
ggml_tensor * ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3200);
|
||||
cur = llm_build_norm(ctx, ffn_inp, ffn_norm);
|
||||
ggml_tensor * ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
cur = llm_build_norm(ctx, ffn_inp, ffn_norm, nullptr, LLM_NORM_RMS);
|
||||
|
||||
ggml_tensor * ffn_up = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 8640);
|
||||
ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 8640);
|
||||
ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 8640, 3200);
|
||||
ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
|
||||
ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);
|
||||
ggml_tensor * ffn_up = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
|
||||
struct ggml_tensor * tmp = ggml_mul_mat(ctx, ffn_up, cur);
|
||||
cur = ggml_mul_mat(ctx, ffn_gate, cur);
|
||||
cur = ggml_silu(ctx, cur);
|
||||
|
@ -1658,29 +1722,138 @@ struct test_llama : public test_case {
|
|||
|
||||
cur = inpL;
|
||||
|
||||
ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3200);
|
||||
cur = llm_build_norm(ctx, cur, output_norm);
|
||||
ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
cur = llm_build_norm(ctx, cur, output_norm, nullptr, LLM_NORM_RMS);
|
||||
|
||||
// lm_head
|
||||
ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 32000);
|
||||
ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_vocab);
|
||||
cur = ggml_mul_mat(ctx, output, cur);
|
||||
|
||||
return cur;
|
||||
}
|
||||
};
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
// pos
|
||||
std::vector<int> data(n_tokens);
|
||||
for (int i = 0; i < n_tokens; i++) {
|
||||
data[i] = rand() % n_ctx;
|
||||
}
|
||||
ggml_backend_tensor_set(t, data.data(), 0, n_tokens * sizeof(int));
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
// Falcon
|
||||
struct test_falcon : public test_llm {
|
||||
static constexpr float freq_base = 10000.0f;
|
||||
static constexpr float freq_scale = 1.0f;
|
||||
static constexpr float ext_factor = 0.0f;
|
||||
static constexpr float attn_factor = 1.0f;
|
||||
static constexpr float beta_fast = 32.0f;
|
||||
static constexpr float beta_slow = 1.0f;
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "FALCON";
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
auto n_tokens = hp.n_tokens;
|
||||
return VARS_TO_STR1(n_tokens);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 2e-3;
|
||||
}
|
||||
|
||||
test_falcon(int n_tokens = 1)
|
||||
: test_llm({
|
||||
/*n_vocab =*/ 65024,
|
||||
/*n_embd =*/ 4544,
|
||||
/*n_head =*/ 71,
|
||||
/*n_head_kv =*/ 1,
|
||||
/*n_rot =*/ 64,
|
||||
/*n_embd_head =*/ 64,
|
||||
/*n_ff =*/ 18176,
|
||||
/*f_norm_eps =*/ 1e-5f,
|
||||
/*f_norm_rms_eps =*/ 0.f,
|
||||
/*n_tokens =*/ n_tokens,
|
||||
}) {
|
||||
}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
|
||||
|
||||
ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
||||
ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
||||
|
||||
for (uint32_t il = 0; il < hp.n_layer; ++il) {
|
||||
// norm
|
||||
ggml_tensor * attn_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
ggml_tensor * attn_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
ggml_tensor * attn_norm = llm_build_norm(ctx, inpL, attn_norm_w, attn_norm_b, LLM_NORM);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
cur = attn_norm;
|
||||
|
||||
ggml_tensor * wqkv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd + 2*hp.n_embd_gqa());
|
||||
|
||||
cur = ggml_mul_mat(ctx, wqkv, cur);
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd, hp.n_tokens, cur->nb[1], 0*sizeof(float)*(hp.n_embd)));
|
||||
struct ggml_tensor * Kcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd)));
|
||||
struct ggml_tensor * Vcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd + hp.n_embd_gqa())));
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);
|
||||
|
||||
// using mode = 2 for neox mode
|
||||
Qcur = ggml_rope_custom(
|
||||
ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_custom(
|
||||
ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
|
||||
|
||||
cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = cur;
|
||||
|
||||
// feed forward
|
||||
{
|
||||
ggml_tensor * ffn_up = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
|
||||
ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);
|
||||
cur = attn_norm;
|
||||
cur = ggml_mul_mat(ctx, ffn_up, cur);
|
||||
cur = ggml_gelu(ctx, cur);
|
||||
cur = ggml_mul_mat(ctx, ffn_down, cur);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx, cur, ffn_inp);
|
||||
|
||||
cur = ggml_add(ctx, cur, inpL);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
ggml_tensor * output_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
cur = llm_build_norm(ctx, cur, output_norm, output_norm_b, LLM_NORM);
|
||||
|
||||
// lm_head
|
||||
ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, hp.n_embd, hp.n_vocab);
|
||||
cur = ggml_mul_mat(ctx, output, cur);
|
||||
|
||||
return cur;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1821,6 +1994,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5));
|
||||
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
|
||||
|
||||
#if 0
|
||||
std::uniform_int_distribution<> dist_ne1(1, 50);
|
||||
int exponent = 1;
|
||||
while (exponent < (1 << 17)) {
|
||||
|
@ -1834,6 +2008,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
|
||||
exponent <<= 1;
|
||||
}
|
||||
#endif
|
||||
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, 0.1f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, 0.1f, true));
|
||||
|
@ -1876,6 +2051,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
|
||||
test_cases.emplace_back(new test_llama(1));
|
||||
test_cases.emplace_back(new test_llama(2));
|
||||
test_cases.emplace_back(new test_falcon(1));
|
||||
test_cases.emplace_back(new test_falcon(2));
|
||||
|
||||
// run tests
|
||||
if (mode == MODE_TEST) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue