CUDA: add FP32 FlashAttention vector kernel (#7188)

* CUDA: add FP32 FlashAttention vector kernel

Johannes Gäßler 2024-05-12 19:40:45 +02:00 committed by GitHub
parent 6f1b63606f
commit dc685be466
9 changed files with 899 additions and 458 deletions


@@ -0,0 +1,5 @@
#include "common.cuh"
void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
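The declarations above expose FlashAttention vector kernels operating on one query at a time. The core numerical trick such kernels rely on is the online-softmax recurrence: scores are reduced with a running maximum and denominator so the full score row never needs to be materialized. Below is a minimal, CPU-side C++ sketch of that recurrence in FP32; the function name and plain `std::vector` interfaces are illustrative, not the CUDA implementation from this commit.

```cpp
#include <cmath>
#include <vector>

// Hypothetical CPU sketch of the per-query online-softmax accumulation
// used by FlashAttention-style kernels (not the actual CUDA code).
std::vector<float> flash_attn_row(const std::vector<float>& q,
                                  const std::vector<std::vector<float>>& K,
                                  const std::vector<std::vector<float>>& V,
                                  float scale) {
    const size_t D = q.size();
    float m = -INFINITY;             // running maximum of the scaled scores
    float l = 0.0f;                  // running softmax denominator
    std::vector<float> acc(D, 0.0f); // unnormalized output accumulator

    for (size_t j = 0; j < K.size(); ++j) {
        // scaled dot-product score for key j
        float s = 0.0f;
        for (size_t d = 0; d < D; ++d) s += q[d] * K[j][d];
        s *= scale;

        // update the running max and rescale previous state accordingly
        const float m_new = std::max(m, s);
        const float corr  = std::exp(m - m_new);
        const float p     = std::exp(s - m_new);

        l = l * corr + p;
        for (size_t d = 0; d < D; ++d) acc[d] = acc[d] * corr + p * V[j][d];
        m = m_new;
    }
    // final normalization by the softmax denominator
    for (size_t d = 0; d < D; ++d) acc[d] /= l;
    return acc;
}
```

A real kernel tiles K/V across threads and keeps `m`, `l`, and `acc` in registers, but the rescale-by-`exp(m - m_new)` step is the same idea in FP16 and FP32 variants.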