Nomic Vulkan backend (#4456)
Signed-off-by: Jared Van Bortel <jared@nomic.ai> Co-authored-by: niansa <anton-sa@web.de> Co-authored-by: Adam Treat <treat.adam@gmail.com> Co-authored-by: Aaron Miller <apage43@ninjawhale.com> Co-authored-by: ToKiNoBug <tokinobug@163.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
2aed77eb06
commit
fbf1ddec69
45 changed files with 4271 additions and 19 deletions
|
@ -370,12 +370,15 @@ struct test_case {
|
|||
printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
|
||||
fflush(stdout);
|
||||
|
||||
// check if backends support op
|
||||
// check if the backends support the ops
|
||||
bool supported = true;
|
||||
for (ggml_backend_t backend : {backend1, backend2}) {
|
||||
if (!ggml_backend_supports_op(backend, out)) {
|
||||
printf("not supported [%s] ", ggml_backend_name(backend));
|
||||
supported = false;
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (!ggml_backend_supports_op(backend, t)) {
|
||||
printf("not supported [%s] ", ggml_backend_name(backend));
|
||||
supported = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!supported) {
|
||||
|
@ -626,6 +629,13 @@ struct test_unary : public test_case {
|
|||
ggml_tensor * out = ggml_unary(ctx, in, op);
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
// test extended range of values to check for NaNs in GELU
|
||||
init_tensor_uniform(t, -150.f, 150.f);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_GET_ROWS
|
||||
|
@ -1066,18 +1076,24 @@ struct test_diag_mask_inf : public test_case {
|
|||
struct test_soft_max : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne;
|
||||
const float scale;
|
||||
const bool mask;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR2(type, ne);
|
||||
return VARS_TO_STR4(type, ne, scale, mask);
|
||||
}
|
||||
|
||||
test_soft_max(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {10, 10, 10, 10})
|
||||
: type(type), ne(ne) {}
|
||||
std::array<int64_t, 4> ne = {10, 10, 10, 10},
|
||||
float scale = 1.0f,
|
||||
bool mask = false)
|
||||
: type(type), ne(ne), scale(scale), mask(mask) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * out = ggml_soft_max(ctx, a);
|
||||
ggml_tensor * b = nullptr;
|
||||
if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); }
|
||||
ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
@ -1474,6 +1490,393 @@ struct test_moe : public test_case {
|
|||
}
|
||||
};
|
||||
|
||||
|
||||
enum llm_norm_type {
|
||||
LLM_NORM,
|
||||
LLM_NORM_RMS,
|
||||
};
|
||||
|
||||
struct llama_hparams {
|
||||
uint32_t n_vocab;
|
||||
uint32_t n_embd;
|
||||
uint32_t n_head;
|
||||
uint32_t n_head_kv;
|
||||
static constexpr uint32_t n_layer = 1;
|
||||
uint32_t n_rot;
|
||||
uint32_t n_embd_head; // dimension of values (d_v)
|
||||
uint32_t n_ff;
|
||||
|
||||
float f_norm_eps;
|
||||
float f_norm_rms_eps;
|
||||
|
||||
// cparams
|
||||
static constexpr uint32_t n_ctx = 512; // user-specified context size
|
||||
static constexpr uint32_t n_orig_ctx = n_ctx;
|
||||
|
||||
// batch
|
||||
int32_t n_tokens;
|
||||
|
||||
// llm_build_context
|
||||
static constexpr int32_t n_kv = 32; // size of KV cache to consider (n_kv <= n_ctx
|
||||
static constexpr int32_t kv_head = 1; // index of where we store new KV data in the cache
|
||||
|
||||
uint32_t n_embd_gqa() const { // dimension of key embeddings across all k-v heads
|
||||
return n_embd_head * n_head_kv;
|
||||
}
|
||||
};
|
||||
|
||||
// LLM base class
|
||||
struct test_llm : public test_case {
|
||||
llama_hparams hp;
|
||||
|
||||
protected:
|
||||
test_llm(llama_hparams hp)
|
||||
: hp(std::move(hp)) {
|
||||
}
|
||||
|
||||
public:
|
||||
struct ggml_tensor * llm_build_norm(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * cur,
|
||||
struct ggml_tensor * mw,
|
||||
struct ggml_tensor * mb,
|
||||
llm_norm_type type) {
|
||||
switch (type) {
|
||||
case LLM_NORM: cur = ggml_norm (ctx, cur, hp.f_norm_eps); break;
|
||||
case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hp.f_norm_rms_eps); break;
|
||||
}
|
||||
cur = ggml_mul(ctx, cur, mw);
|
||||
if (mb) {
|
||||
cur = ggml_add(ctx, cur, mb);
|
||||
}
|
||||
return cur;
|
||||
}
|
||||
|
||||
void llm_build_kv_store(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k_l,
|
||||
struct ggml_tensor * v_l,
|
||||
struct ggml_tensor * k_cur,
|
||||
struct ggml_tensor * v_cur) {
|
||||
// compute the transposed [n_tokens, n_embd] V matrix
|
||||
struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, hp.n_embd_gqa(), hp.n_tokens));
|
||||
|
||||
struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, k_l, hp.n_tokens*hp.n_embd_gqa(),
|
||||
(ggml_row_size(k_l->type, hp.n_embd_gqa()))*hp.kv_head);
|
||||
|
||||
struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, v_l, hp.n_tokens, hp.n_embd_gqa(),
|
||||
( hp.n_ctx)*ggml_element_size(v_l),
|
||||
(hp.kv_head)*ggml_element_size(v_l));
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_cpy(ctx, k_cur, k_cache_view);
|
||||
ggml_cpy(ctx, v_cur_t, v_cache_view);
|
||||
}
|
||||
|
||||
// if max_alibi_bias > 0 then apply ALiBi
|
||||
struct ggml_tensor * llm_build_kqv(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k_l,
|
||||
struct ggml_tensor * v_l,
|
||||
struct ggml_tensor * q_cur,
|
||||
struct ggml_tensor * kq_mask,
|
||||
float kq_scale) {
|
||||
struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * k =
|
||||
ggml_view_3d(ctx, k_l,
|
||||
hp.n_embd_head, hp.n_kv, hp.n_head_kv,
|
||||
ggml_row_size(k_l->type, hp.n_embd_gqa()),
|
||||
ggml_row_size(k_l->type, hp.n_embd_head),
|
||||
0);
|
||||
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||
|
||||
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
|
||||
|
||||
// split cached v into n_head heads
|
||||
struct ggml_tensor * v =
|
||||
ggml_view_3d(ctx, v_l,
|
||||
hp.n_kv, hp.n_embd_head, hp.n_head_kv,
|
||||
ggml_element_size(v_l)*hp.n_ctx,
|
||||
ggml_element_size(v_l)*hp.n_ctx*hp.n_embd_head,
|
||||
0);
|
||||
|
||||
struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);
|
||||
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
|
||||
|
||||
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, hp.n_embd_head*hp.n_head, hp.n_tokens);
|
||||
|
||||
struct ggml_tensor * wo = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
|
||||
cur = ggml_mul_mat(ctx, wo, cur);
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
// pos
|
||||
std::vector<int> data(hp.n_tokens);
|
||||
for (int i = 0; i < hp.n_tokens; i++) {
|
||||
data[i] = rand() % hp.n_ctx;
|
||||
}
|
||||
ggml_backend_tensor_set(t, data.data(), 0, hp.n_tokens * sizeof(int));
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Llama
|
||||
struct test_llama : public test_llm {
|
||||
static constexpr float freq_base = 10000.0f;
|
||||
static constexpr float freq_scale = 1.0f;
|
||||
static constexpr float ext_factor = 0.0f;
|
||||
static constexpr float attn_factor = 1.0f;
|
||||
static constexpr float beta_fast = 32.0f;
|
||||
static constexpr float beta_slow = 1.0f;
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "LLAMA";
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
auto n_tokens = hp.n_tokens;
|
||||
return VARS_TO_STR1(n_tokens);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 2e-3;
|
||||
}
|
||||
|
||||
test_llama(int n_tokens = 1)
|
||||
: test_llm({
|
||||
/*n_vocab =*/ 32000,
|
||||
/*n_embd =*/ 3200,
|
||||
/*n_head =*/ 32,
|
||||
/*n_head_kv =*/ 32,
|
||||
/*n_rot =*/ 100,
|
||||
/*n_embd_head =*/ 100,
|
||||
/*n_ff =*/ 8640,
|
||||
/*f_norm_eps =*/ 0.f,
|
||||
/*f_norm_rms_eps =*/ 1e-5f,
|
||||
/*n_tokens =*/ n_tokens,
|
||||
}) {
|
||||
}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
|
||||
|
||||
ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
||||
ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
||||
|
||||
for (uint32_t il = 0; il < hp.n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
ggml_tensor * attn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
cur = llm_build_norm(ctx, inpL, attn_norm, nullptr, LLM_NORM_RMS);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
ggml_tensor * wq = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
|
||||
ggml_tensor * wk = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
|
||||
ggml_tensor * wv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx, wq, cur);
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);
|
||||
|
||||
Qcur = ggml_rope_custom(
|
||||
ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens), inp_pos,
|
||||
hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_custom(
|
||||
ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos,
|
||||
hp.n_rot, 0, 0, hp.n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
|
||||
|
||||
cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx, cur, inpSA);
|
||||
|
||||
// feed-forward network
|
||||
ggml_tensor * ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
cur = llm_build_norm(ctx, ffn_inp, ffn_norm, nullptr, LLM_NORM_RMS);
|
||||
|
||||
ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
|
||||
ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);
|
||||
ggml_tensor * ffn_up = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
|
||||
struct ggml_tensor * tmp = ggml_mul_mat(ctx, ffn_up, cur);
|
||||
cur = ggml_mul_mat(ctx, ffn_gate, cur);
|
||||
cur = ggml_silu(ctx, cur);
|
||||
cur = ggml_mul(ctx, cur, tmp);
|
||||
cur = ggml_mul_mat(ctx, ffn_down, cur);
|
||||
|
||||
cur = ggml_add(ctx, cur, ffn_inp);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
cur = llm_build_norm(ctx, cur, output_norm, nullptr, LLM_NORM_RMS);
|
||||
|
||||
// lm_head
|
||||
ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_vocab);
|
||||
cur = ggml_mul_mat(ctx, output, cur);
|
||||
|
||||
return cur;
|
||||
}
|
||||
};
|
||||
|
||||
// Falcon
|
||||
struct test_falcon : public test_llm {
|
||||
static constexpr float freq_base = 10000.0f;
|
||||
static constexpr float freq_scale = 1.0f;
|
||||
static constexpr float ext_factor = 0.0f;
|
||||
static constexpr float attn_factor = 1.0f;
|
||||
static constexpr float beta_fast = 32.0f;
|
||||
static constexpr float beta_slow = 1.0f;
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "FALCON";
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
auto n_tokens = hp.n_tokens;
|
||||
return VARS_TO_STR1(n_tokens);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 2e-3;
|
||||
}
|
||||
|
||||
test_falcon(int n_tokens = 1)
|
||||
: test_llm({
|
||||
/*n_vocab =*/ 32000,
|
||||
/*n_embd =*/ 3200,
|
||||
/*n_head =*/ 50,
|
||||
/*n_head_kv =*/ 1,
|
||||
/*n_rot =*/ 64,
|
||||
/*n_embd_head =*/ 64,
|
||||
/*n_ff =*/ 8640,
|
||||
/*f_norm_eps =*/ 1e-5f,
|
||||
/*f_norm_rms_eps =*/ 0.f,
|
||||
/*n_tokens =*/ n_tokens,
|
||||
}) {
|
||||
}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
|
||||
|
||||
ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
||||
ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
|
||||
|
||||
for (uint32_t il = 0; il < hp.n_layer; ++il) {
|
||||
// norm
|
||||
ggml_tensor * attn_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
ggml_tensor * attn_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
ggml_tensor * attn_norm = llm_build_norm(ctx, inpL, attn_norm_w, attn_norm_b, LLM_NORM);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
cur = attn_norm;
|
||||
|
||||
ggml_tensor * wqkv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd + 2*hp.n_embd_gqa());
|
||||
|
||||
cur = ggml_mul_mat(ctx, wqkv, cur);
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd, hp.n_tokens, cur->nb[1], 0*sizeof(float)*(hp.n_embd)));
|
||||
struct ggml_tensor * Kcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd)));
|
||||
struct ggml_tensor * Vcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd + hp.n_embd_gqa())));
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head, hp.n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);
|
||||
|
||||
// using mode = 2 for neox mode
|
||||
Qcur = ggml_rope_custom(
|
||||
ctx, Qcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_custom(
|
||||
ctx, Kcur, inp_pos, hp.n_rot, 2, 0, hp.n_orig_ctx,
|
||||
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);
|
||||
|
||||
cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = cur;
|
||||
|
||||
// feed forward
|
||||
{
|
||||
ggml_tensor * ffn_up = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
|
||||
ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);
|
||||
cur = attn_norm;
|
||||
cur = ggml_mul_mat(ctx, ffn_up, cur);
|
||||
cur = ggml_gelu(ctx, cur);
|
||||
cur = ggml_mul_mat(ctx, ffn_down, cur);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx, cur, ffn_inp);
|
||||
|
||||
cur = ggml_add(ctx, cur, inpL);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
ggml_tensor * output_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
|
||||
cur = llm_build_norm(ctx, cur, output_norm, output_norm_b, LLM_NORM);
|
||||
|
||||
// lm_head
|
||||
ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, hp.n_embd, hp.n_vocab);
|
||||
cur = ggml_mul_mat(ctx, output, cur);
|
||||
|
||||
return cur;
|
||||
}
|
||||
};
|
||||
|
||||
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
|
||||
std::vector<std::unique_ptr<test_case>> test_cases;
|
||||
std::default_random_engine rng(0);
|
||||
|
@ -1626,6 +2029,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
exponent <<= 1;
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, 0.1f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, 0.1f, true));
|
||||
|
||||
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
|
||||
test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512)); // llama 13B
|
||||
|
@ -1662,6 +2068,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
//test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
|
||||
#endif
|
||||
|
||||
// these tests are disabled to save execution time, but they can be handy for debugging
|
||||
#if 0
|
||||
test_cases.emplace_back(new test_llama(1));
|
||||
test_cases.emplace_back(new test_llama(2));
|
||||
test_cases.emplace_back(new test_falcon(1));
|
||||
test_cases.emplace_back(new test_falcon(2));
|
||||
#endif
|
||||
|
||||
// run tests
|
||||
if (mode == MODE_TEST) {
|
||||
ggml_backend_t backend_cpu = ggml_backend_cpu_init();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue