test-backend-ops : add Falcon test

This commit is contained in:
Jared Van Bortel 2024-01-25 13:55:49 -05:00
parent f5ac635473
commit 1849b85473

View file

@ -1471,53 +1471,64 @@ struct test_moe : public test_case {
}; };
// llama enum llm_norm_type {
struct test_llama : public test_case { LLM_NORM,
const int n_tokens; LLM_NORM_RMS,
static constexpr float f_norm_rms_eps = 1e-5; };
static constexpr int64_t n_embd_k_gqa = 3200;
static constexpr int64_t n_embd_v_gqa = 3200;
static constexpr int64_t n_ctx = 512;
static constexpr int64_t n_layer = 1;
static constexpr int64_t n_head = 32;
static constexpr int64_t n_head_kv = 32;
static constexpr int64_t n_embd_head = 100;
static constexpr int64_t n_embd = 3200;
static constexpr int64_t n_orig_ctx = n_ctx;
static constexpr int64_t n_ff = 8640;
static constexpr int64_t n_kv = 32;
static constexpr int64_t kv_head = 1;
static constexpr float freq_base = 10000.0f;
static constexpr float freq_scale = 1.0f;
static constexpr float ext_factor = 0.0f;
static constexpr float attn_factor = 1.0f;
static constexpr float beta_fast = 32.0f;
static constexpr float beta_slow = 1.0f;
std::string op_desc(ggml_tensor * t) override { struct llama_hparams {
return "LLAMA"; uint32_t n_vocab;
uint32_t n_embd;
uint32_t n_head;
uint32_t n_head_kv;
static constexpr uint32_t n_layer = 1;
uint32_t n_rot;
uint32_t n_embd_head; // dimension of values (d_v)
uint32_t n_ff;
GGML_UNUSED(t); float f_norm_eps;
} float f_norm_rms_eps;
std::string vars() override { // cparams
return VARS_TO_STR1(n_tokens); static constexpr uint32_t n_ctx = 512; // user-specified context size
} static constexpr uint32_t n_orig_ctx = n_ctx;
double max_nmse_err() override { // batch
return 2e-3; int32_t n_tokens;
}
// llm_build_context
test_llama(int n_tokens = 1) static constexpr int32_t n_kv = 32; // size of KV cache to consider (n_kv <= n_ctx
: n_tokens(n_tokens) { static constexpr int32_t kv_head = 1; // index of where we store new KV data in the cache
uint32_t n_embd_gqa() const { // dimension of key embeddings across all k-v heads
return n_embd_head * n_head_kv;
}
};
// LLM base class
struct test_llm : public test_case {
llama_hparams hp;
protected:
test_llm(llama_hparams hp)
: hp(std::move(hp)) {
} }
public:
struct ggml_tensor * llm_build_norm( struct ggml_tensor * llm_build_norm(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * cur, struct ggml_tensor * cur,
struct ggml_tensor * mw) { struct ggml_tensor * mw,
cur = ggml_rms_norm(ctx, cur, f_norm_rms_eps); struct ggml_tensor * mb,
llm_norm_type type) {
switch (type) {
case LLM_NORM: cur = ggml_norm (ctx, cur, hp.f_norm_eps); break;
case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hp.f_norm_rms_eps); break;
}
cur = ggml_mul(ctx, cur, mw); cur = ggml_mul(ctx, cur, mw);
if (mb) {
cur = ggml_add(ctx, cur, mb);
}
return cur; return cur;
} }
@ -1528,14 +1539,14 @@ struct test_llama : public test_case {
struct ggml_tensor * k_cur, struct ggml_tensor * k_cur,
struct ggml_tensor * v_cur) { struct ggml_tensor * v_cur) {
// compute the transposed [n_tokens, n_embd] V matrix // compute the transposed [n_tokens, n_embd] V matrix
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, hp.n_embd_gqa(), hp.n_tokens));
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, k_l, n_tokens*n_embd_k_gqa, struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, k_l, hp.n_tokens*hp.n_embd_gqa(),
(ggml_row_size(k_l->type, n_embd_k_gqa))*kv_head); (ggml_row_size(k_l->type, hp.n_embd_gqa()))*hp.kv_head);
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, v_l, n_tokens, n_embd_v_gqa, struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, v_l, hp.n_tokens, hp.n_embd_gqa(),
( n_ctx)*ggml_element_size(v_l), ( hp.n_ctx)*ggml_element_size(v_l),
(kv_head)*ggml_element_size(v_l)); (hp.kv_head)*ggml_element_size(v_l));
// important: storing RoPE-ed version of K in the KV cache! // important: storing RoPE-ed version of K in the KV cache!
ggml_cpy(ctx, k_cur, k_cache_view); ggml_cpy(ctx, k_cur, k_cache_view);
@ -1554,9 +1565,9 @@ struct test_llama : public test_case {
struct ggml_tensor * k = struct ggml_tensor * k =
ggml_view_3d(ctx, k_l, ggml_view_3d(ctx, k_l,
n_embd_head, n_kv, n_head_kv, hp.n_embd_head, hp.n_kv, hp.n_head_kv,
ggml_row_size(k_l->type, n_embd_k_gqa), ggml_row_size(k_l->type, hp.n_embd_gqa()),
ggml_row_size(k_l->type, n_embd_head), ggml_row_size(k_l->type, hp.n_embd_head),
0); 0);
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
@ -1566,52 +1577,105 @@ struct test_llama : public test_case {
// split cached v into n_head heads // split cached v into n_head heads
struct ggml_tensor * v = struct ggml_tensor * v =
ggml_view_3d(ctx, v_l, ggml_view_3d(ctx, v_l,
n_kv, n_embd_head, n_head_kv, hp.n_kv, hp.n_embd_head, hp.n_head_kv,
ggml_element_size(v_l)*n_ctx, ggml_element_size(v_l)*hp.n_ctx,
ggml_element_size(v_l)*n_ctx*n_embd_head, ggml_element_size(v_l)*hp.n_ctx*hp.n_embd_head,
0); 0);
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head*n_head, n_tokens); struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, hp.n_embd_head*hp.n_head, hp.n_tokens);
struct ggml_tensor * wo = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 3200); struct ggml_tensor * wo = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
cur = ggml_mul_mat(ctx, wo, cur); cur = ggml_mul_mat(ctx, wo, cur);
return cur; return cur;
} }
ggml_tensor * build_graph(ggml_context * ctx) override { void initialize_tensors(ggml_context * ctx) override {
const int64_t n_rot = n_embd_head; for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
// pos
std::vector<int> data(hp.n_tokens);
for (int i = 0; i < hp.n_tokens; i++) {
data[i] = rand() % hp.n_ctx;
}
ggml_backend_tensor_set(t, data.data(), 0, hp.n_tokens * sizeof(int));
} else {
init_tensor_uniform(t);
}
}
}
};
// Llama
struct test_llama : public test_llm {
static constexpr float freq_base = 10000.0f;
static constexpr float freq_scale = 1.0f;
static constexpr float ext_factor = 0.0f;
static constexpr float attn_factor = 1.0f;
static constexpr float beta_fast = 32.0f;
static constexpr float beta_slow = 1.0f;
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "LLAMA";
}
std::string vars() override {
auto n_tokens = hp.n_tokens;
return VARS_TO_STR1(n_tokens);
}
double max_nmse_err() override {
return 2e-3;
}
test_llama(int n_tokens = 1)
: test_llm({
/*n_vocab =*/ 32000,
/*n_embd =*/ 3200,
/*n_head =*/ 32,
/*n_head_kv =*/ 32,
/*n_rot =*/ 100,
/*n_embd_head =*/ 100,
/*n_ff =*/ 8640,
/*f_norm_eps =*/ 0.f,
/*f_norm_rms_eps =*/ 1e-5f,
/*n_tokens =*/ n_tokens,
}) {
}
ggml_tensor * build_graph(ggml_context * ctx) override {
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
// inp_pos - contains the positions // inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads) // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_kv, n_tokens, 1); struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
for (int il = 0; il < n_layer; ++il) { for (uint32_t il = 0; il < hp.n_layer; ++il) {
struct ggml_tensor * inpSA = inpL; struct ggml_tensor * inpSA = inpL;
// norm // norm
ggml_tensor * attn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3200); ggml_tensor * attn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
cur = llm_build_norm(ctx, inpL, attn_norm); cur = llm_build_norm(ctx, inpL, attn_norm, nullptr, LLM_NORM_RMS);
// self-attention // self-attention
{ {
ggml_tensor * wq = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 3200); ggml_tensor * wq = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
ggml_tensor * wk = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 3200); ggml_tensor * wk = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
ggml_tensor * wv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 3200); ggml_tensor * wv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
// compute Q and K and RoPE them // compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx, wq, cur); struct ggml_tensor * Qcur = ggml_mul_mat(ctx, wq, cur);
@ -1619,31 +1683,31 @@ struct test_llama : public test_case {
struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur); struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);
Qcur = ggml_rope_custom( Qcur = ggml_rope_custom(
ctx, ggml_reshape_3d(ctx, Qcur, n_embd_head, n_head, n_tokens), inp_pos, ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos,
n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale, hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow ext_factor, attn_factor, beta_fast, beta_slow
); );
Kcur = ggml_rope_custom( Kcur = ggml_rope_custom(
ctx, ggml_reshape_3d(ctx, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos,
n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale, hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow ext_factor, attn_factor, beta_fast, beta_slow
); );
llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur); llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(n_embd_head))); cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
} }
struct ggml_tensor * ffn_inp = ggml_add(ctx, cur, inpSA); struct ggml_tensor * ffn_inp = ggml_add(ctx, cur, inpSA);
// feed-forward network // feed-forward network
ggml_tensor * ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3200); ggml_tensor * ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
cur = llm_build_norm(ctx, ffn_inp, ffn_norm); cur = llm_build_norm(ctx, ffn_inp, ffn_norm, nullptr, LLM_NORM_RMS);
ggml_tensor * ffn_up = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 8640); ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 8640); ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);
ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 8640, 3200); ggml_tensor * ffn_up = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
struct ggml_tensor * tmp = ggml_mul_mat(ctx, ffn_up, cur); struct ggml_tensor * tmp = ggml_mul_mat(ctx, ffn_up, cur);
cur = ggml_mul_mat(ctx, ffn_gate, cur); cur = ggml_mul_mat(ctx, ffn_gate, cur);
cur = ggml_silu(ctx, cur); cur = ggml_silu(ctx, cur);
@ -1658,29 +1722,138 @@ struct test_llama : public test_case {
cur = inpL; cur = inpL;
ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 3200); ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
cur = llm_build_norm(ctx, cur, output_norm); cur = llm_build_norm(ctx, cur, output_norm, nullptr, LLM_NORM_RMS);
// lm_head // lm_head
ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 3200, 32000); ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_vocab);
cur = ggml_mul_mat(ctx, output, cur); cur = ggml_mul_mat(ctx, output, cur);
return cur; return cur;
} }
};
void initialize_tensors(ggml_context * ctx) override { // Falcon
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { struct test_falcon : public test_llm {
if (t->type == GGML_TYPE_I32) { static constexpr float freq_base = 10000.0f;
// pos static constexpr float freq_scale = 1.0f;
std::vector<int> data(n_tokens); static constexpr float ext_factor = 0.0f;
for (int i = 0; i < n_tokens; i++) { static constexpr float attn_factor = 1.0f;
data[i] = rand() % n_ctx; static constexpr float beta_fast = 32.0f;
} static constexpr float beta_slow = 1.0f;
ggml_backend_tensor_set(t, data.data(), 0, n_tokens * sizeof(int));
} else { std::string op_desc(ggml_tensor * t) override {
init_tensor_uniform(t); GGML_UNUSED(t);
return "FALCON";
}
std::string vars() override {
auto n_tokens = hp.n_tokens;
return VARS_TO_STR1(n_tokens);
}
double max_nmse_err() override {
return 2e-3;
}
test_falcon(int n_tokens = 1)
: test_llm({
/*n_vocab =*/ 65024,
/*n_embd =*/ 4544,
/*n_head =*/ 71,
/*n_head_kv =*/ 1,
/*n_rot =*/ 64,
/*n_embd_head =*/ 64,
/*n_ff =*/ 18176,
/*f_norm_eps =*/ 1e-5f,
/*f_norm_rms_eps =*/ 0.f,
/*n_tokens =*/ n_tokens,
}) {
}
ggml_tensor * build_graph(ggml_context * ctx) override {
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
for (uint32_t il = 0; il < hp.n_layer; ++il) {
// norm
ggml_tensor * attn_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
ggml_tensor * attn_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
ggml_tensor * attn_norm = llm_build_norm(ctx, inpL, attn_norm_w, attn_norm_b, LLM_NORM);
// self-attention
{
cur = attn_norm;
ggml_tensor * wqkv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd + 2*hp.n_embd_gqa());
cur = ggml_mul_mat(ctx, wqkv, cur);
struct ggml_tensor * Qcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd, hp.n_tokens, cur->nb[1], 0*sizeof(float)*(hp.n_embd)));
struct ggml_tensor * Kcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd)));
struct ggml_tensor * Vcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd + hp.n_embd_gqa())));
Qcur = ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens);
Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);
// using mode = 2 for neox mode
Qcur = ggml_rope_custom(
ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_custom(
ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
} }
struct ggml_tensor * ffn_inp = cur;
// feed forward
{
ggml_tensor * ffn_up = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);
cur = attn_norm;
cur = ggml_mul_mat(ctx, ffn_up, cur);
cur = ggml_gelu(ctx, cur);
cur = ggml_mul_mat(ctx, ffn_down, cur);
}
cur = ggml_add(ctx, cur, ffn_inp);
cur = ggml_add(ctx, cur, inpL);
// input for next layer
inpL = cur;
} }
cur = inpL;
ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
ggml_tensor * output_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
cur = llm_build_norm(ctx, cur, output_norm, output_norm_b, LLM_NORM);
// lm_head
ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, hp.n_embd, hp.n_vocab);
cur = ggml_mul_mat(ctx, output, cur);
return cur;
} }
}; };
@ -1821,6 +1994,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5)); test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5));
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5)); test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
#if 0
std::uniform_int_distribution<> dist_ne1(1, 50); std::uniform_int_distribution<> dist_ne1(1, 50);
int exponent = 1; int exponent = 1;
while (exponent < (1 << 17)) { while (exponent < (1 << 17)) {
@ -1834,6 +2008,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
exponent <<= 1; exponent <<= 1;
} }
#endif
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, 0.1f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, 0.1f));
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, 0.1f, true)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, 0.1f, true));
@ -1876,6 +2051,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_llama(1)); test_cases.emplace_back(new test_llama(1));
test_cases.emplace_back(new test_llama(2)); test_cases.emplace_back(new test_llama(2));
test_cases.emplace_back(new test_falcon(1));
test_cases.emplace_back(new test_falcon(2));
// run tests // run tests
if (mode == MODE_TEST) { if (mode == MODE_TEST) {