ggml : add ggml_flash_attn_ext API
commit a1c004ef2e
parent ad19812cda
6 changed files with 456 additions and 38 deletions
tests/test-backend-ops.cpp
@@ -1384,6 +1384,32 @@ struct test_leaky_relu : public test_case {
     }
 };
 
+// GGML_OP_FLASH_ATTN_EXT
+struct test_flash_attn_ext : public test_case {
+    const ggml_type typeq;
+    const int64_t hs; // head size
+    const int64_t nh; // num heads
+    const int64_t kv; // kv size
+    const int64_t nt; // tokens
+
+    std::string vars() override {
+        return VARS_TO_STR5(typeq, hs, nh, kv, nt);
+    }
+
+    test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16,
+            int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nt = 8)
+        : typeq(typeq), hs(hs), nh(nh), kv(kv), nt(nt) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nt, nh, 1);
+        ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
+        ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1);
+        ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nt, 1, 1);
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
+        return out;
+    }
+};
+
 // Mixtral MOE
 struct test_moe : public test_case {
     const int n_experts;
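For reference, ggml_flash_attn_ext fuses the usual attention chain into a single graph node. With the shapes used in build_graph above, an unfused graph computing the same quantity could be sketched as follows (an illustration only, not code from this patch; it assumes the float-scalar ggml_scale signature and that ggml_add broadcasts the 2D fp32 mask across the head dimension, and the fused kernel's output layout may differ):

    // out = softmax(scale*(K^T Q) + mask) * V, assembled from existing ggml ops
    ggml_tensor * kq = ggml_mul_mat(ctx, k, q);       // attention scores, shape [kv, nt, nh]
    kq = ggml_scale   (ctx, kq, 1.0f/sqrtf(hs));      // scale by 1/sqrt(head_size)
    kq = ggml_add     (ctx, kq, mask);                // apply mask, broadcast over heads (assumed)
    kq = ggml_soft_max(ctx, kq);                      // normalize over the kv dimension
    ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);     // weighted sum of values, shape [hs, nt, nh]

Note that the test allocates v transposed (kv in the contiguous dimension), which is exactly the layout the second ggml_mul_mat consumes.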
@@ -1650,6 +1676,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());
 
+    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8));
+
 #if !defined(__SANITIZE_THREAD__)
     // FIXME: these tests use too much memory with thread sanitizer
     test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));
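The registration exercises the default geometry (typeq = F16, hs = 128, nh = 32, kv = 96, nt = 8); the harness then compares each backend's result for this graph against the CPU backend. Independent of any backend, the values the op has to reproduce are plain scaled-dot-product attention. A minimal scalar reference for a single head might look like this (a hypothetical helper for illustration, not part of the patch):

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Scalar reference for one attention head (hypothetical, for illustration).
    // q: nt x hs, k: kv x hs, v: kv x hs, mask: nt x kv, out: nt x hs (row-major).
    static void attn_ref(const float * q, const float * k, const float * v,
                         const float * mask, float * out,
                         int64_t hs, int64_t kv, int64_t nt) {
        const float scale = 1.0f/sqrtf((float) hs);
        std::vector<float> s(kv);
        for (int64_t t = 0; t < nt; ++t) {
            // scores: s[j] = scale * dot(q[t], k[j]) + mask[t][j]
            float smax = -INFINITY;
            for (int64_t j = 0; j < kv; ++j) {
                float dot = 0.0f;
                for (int64_t c = 0; c < hs; ++c) {
                    dot += q[t*hs + c]*k[j*hs + c];
                }
                s[j] = scale*dot + mask[t*kv + j];
                smax = fmaxf(smax, s[j]);
            }
            // numerically stable softmax over the kv dimension
            float sum = 0.0f;
            for (int64_t j = 0; j < kv; ++j) {
                s[j] = expf(s[j] - smax);
                sum += s[j];
            }
            // out[t] = sum_j softmax(s)[j] * v[j]
            for (int64_t c = 0; c < hs; ++c) {
                float acc = 0.0f;
                for (int64_t j = 0; j < kv; ++j) {
                    acc += s[j]*v[j*hs + c];
                }
                out[t*hs + c] = acc/sum;
            }
        }
    }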