Fix FlashAttention debug test, FP32 assert (#7684)

This commit is contained in:
Johannes Gäßler 2024-06-01 23:26:10 +02:00 committed by GitHub
parent 2e666832e6
commit e141ce624a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 5 additions and 7 deletions

View file

@@ -278,14 +278,10 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_tensor * KQV = dst;
ggml_tensor * Q = dst->src[0];
ggml_tensor * K = dst->src[1];
ggml_tensor * V = dst->src[2];
const int32_t precision = KQV->op_params[2];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);