Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cpp into flash-attn-cuda

FSSRepo 2024-01-25 09:48:37 -05:00
commit 78da3387a8
3 changed files with 205 additions and 121 deletions

@@ -1397,7 +1397,7 @@ struct test_flash_attn_ext : public test_case {
     }

     double max_nmse_err() override {
-        return 5e-4;
+        return 5e-5;
     }

     test_flash_attn_ext(ggml_type typeq = GGML_TYPE_F16,
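
For context, max_nmse_err() is the per-test tolerance that test-backend-ops uses when comparing a backend's output against the reference backend. The comparison metric is a normalized mean squared error; the sketch below is a generic illustration of such a metric under that assumption, not code copied from this commit.

#include <cstddef>

// Generic normalized MSE: squared error relative to the squared magnitude
// of the reference output. A test case passes if this value stays below
// the threshold returned by its max_nmse_err() override.
static double nmse(const float * ref, const float * out, size_t n) {
    double err    = 0.0;
    double ref_sq = 0.0;
    for (size_t i = 0; i < n; i++) {
        const double d = (double) ref[i] - (double) out[i];
        err    += d * d;
        ref_sq += (double) ref[i] * (double) ref[i];
    }
    return ref_sq > 0.0 ? err / ref_sq : 0.0;
}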
@@ -1680,7 +1680,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());
     test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 96, 8));
     test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 8));
+    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 7));
+    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, 1));

 #if !defined(__SANITIZE_THREAD__)
     // FIXME: these tests use too much memory with thread sanitizer
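
The two added cases reuse the same attention shape and vary only the last constructor argument; assuming the parameters after the type are head size, number of heads, KV length and batch size (an assumption, the diff does not name them), they add coverage for a batch that is not a multiple of the usual SIMD/tile width (7) and for the single-token case (1). A minimal sketch of registering the same shape over a list of batch sizes, under that same assumption:

// Hypothetical variant: one loop instead of repeated emplace_back calls.
// Parameter order (type, head size, num heads, KV length, batch) is assumed.
for (int nb : {1, 7, 8}) {
    test_cases.emplace_back(new test_flash_attn_ext(GGML_TYPE_F16, 128, 32, 256, nb));
}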