llama : avoid ggml_cast, use F32 query

commit f9ca5dcbe8
parent 40ea8cd1ac
Author: Georgi Gerganov
Date: 2024-01-25 17:46:07 +02:00
6 changed files with 44 additions and 17 deletions

tests/test-backend-ops.cpp

@@ -1386,26 +1386,24 @@ struct test_leaky_relu : public test_case {
 // GGML_OP_FLASH_ATTN_EXT
 struct test_flash_attn_ext : public test_case {
-    const ggml_type typeq;
     const int64_t hs; // head size
     const int64_t nh; // num heads
     const int64_t kv; // kv size
     const int64_t nb; // batch size
 
     std::string vars() override {
-        return VARS_TO_STR5(typeq, hs, nh, kv, nb);
+        return VARS_TO_STR4(hs, nh, kv, nb);
     }
 
     double max_nmse_err() override {
         return 5e-5;
     }
 
-    test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16,
-            int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
-        : typeq(typeq), hs(hs), nh(nh), kv(kv), nb(nb) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
+        : hs(hs), nh(nh), kv(kv), nb(nb) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * q = ggml_new_tensor_4d(ctx, typeq, hs, nb, nh, 1);
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, kv, nb, 1, 1);
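The hunk is truncated at the mask tensor. For context, a minimal sketch of how build_graph presumably finishes, assuming the ggml_flash_attn_ext(ctx, q, k, v, mask, scale) signature from this PR and the usual 1/sqrt(head size) attention scale:

        // sketch, not part of the diff: the F32 q is consumed directly,
        // so no ggml_cast of the query is needed anywhere in the graph
        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
        return out;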
@@ -1680,9 +1678,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());
-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8));
-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7));
-    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1));
+    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 8));
+    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 7));
+    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 256, 1));
 
 #if !defined(__SANITIZE_THREAD__)
     // FIXME: these tests use too much memory with thread sanitizer
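The commit title points at the caller side in llama.cpp, which is not part of this excerpt: with the op now accepting an F32 query directly, the query no longer has to be cast to the K/V type before the call. A hypothetical before/after sketch of such a call site (the names q, kq_mask and kq_scale are illustrative, not taken from the diff):

    // before (hypothetical): Q cast to F16 to match K/V, which adds an
    // extra ggml_cast node and an intermediate tensor to the graph
    cur = ggml_flash_attn_ext(ctx, ggml_cast(ctx, q, GGML_TYPE_F16), k, v, kq_mask, kq_scale);

    // after: the F32 Q is passed as-is and the cast node disappears
    cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);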