From 672244a88b678e1e2c17c5fb82ae6c75638bf954 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 21 May 2024 19:38:25 +0200 Subject: [PATCH 1/6] CUDA: quantized KV support for FA vec --- CMakeLists.txt | 4 + Makefile | 5 +- README.md | 1 + ggml-cuda/fattn-common.cuh | 495 +++++++++++++++++++++++++++++++++++- ggml-cuda/fattn-tile-f16.cu | 3 + ggml-cuda/fattn-tile-f32.cu | 3 + ggml-cuda/fattn-vec-f16.cu | 248 ++++++++++++------ ggml-cuda/fattn-vec-f32.cu | 186 ++++++++++---- ggml-cuda/fattn.cu | 11 +- ggml-cuda/mmq.cu | 4 +- ggml-cuda/vecdotq.cuh | 8 +- 11 files changed, 826 insertions(+), 142 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c5add8239..743619b11 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,6 +106,7 @@ set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING "llama: max. batch size for using peer access") option(LLAMA_CUDA_NO_PEER_COPY "llama: do not use peer to peer copies" OFF) option(LLAMA_CUDA_NO_VMM "llama: do not try to use CUDA VMM" OFF) +option(LLAMA_CUDA_FA_ALL_QUANTS "llama: compile all quants for FlashAttention" OFF) option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF) option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) @@ -427,6 +428,9 @@ if (LLAMA_CUDA) if (LLAMA_CUDA_NO_PEER_COPY) add_compile_definitions(GGML_CUDA_NO_PEER_COPY) endif() + if (LLAMA_CUDA_FA_ALL_QUANTS) + add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) + endif() if (LLAMA_STATIC) if (WIN32) diff --git a/Makefile b/Makefile index 5caf31cdf..94cfdab4e 100644 --- a/Makefile +++ b/Makefile @@ -493,7 +493,10 @@ ifdef LLAMA_CUDA_NO_PEER_COPY endif # LLAMA_CUDA_NO_PEER_COPY ifdef LLAMA_CUDA_CCBIN MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN) -endif +endif # LLAMA_CUDA_CCBIN +ifdef LLAMA_CUDA_FA_ALL_QUANTS + MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS +endif # LLAMA_CUDA_FA_ALL_QUANTS ifdef JETSON_EOL_MODULE_DETECT define NVCC_COMPILE diff --git a/README.md b/README.md index 15519c97f..4134f1b8d 100644 --- a/README.md +++ b/README.md @@ -481,6 +481,7 @@ Building the program with BLAS support may lead to some performance improvements | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | LLAMA_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | + | LLAMA_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. 
| - #### hipBLAS diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh index 1dd519bde..89bbc9749 100644 --- a/ggml-cuda/fattn-common.cuh +++ b/ggml-cuda/fattn-common.cuh @@ -1,4 +1,5 @@ #include "common.cuh" +#include "vecdotq.cuh" #include <cstdint> @@ -34,11 +35,463 @@ typedef void (* fattn_kernel_t)( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, const int ne3); +typedef half (*vec_dot_KQ_f16_t)( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); +typedef float (*vec_dot_KQ_f32_t)( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); + +template<typename T, int D> +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ > MIN_CC_DP4A + + const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_0; + const int shift = k_KQ & (QI8_1/2); + + const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = __dp4a(v, u, 0); + +#if FP16_AVAILABLE + if (std::is_same<T, half>::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + } + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ > MIN_CC_DP4A +} + +template<typename T, int D> +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ > MIN_CC_DP4A + + const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_1; + const int shift = k_KQ & (QI8_1/2); + + const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = __dp4a(v, u, 0); + +#if FP16_AVAILABLE + if (std::is_same<T, half>::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1); + sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled)); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; + const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + + sum += (T) (sumid4d8 + m4s8scaled); + } + } + 
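// The algebra behind the two branches above, spelled out (derived from this
// code, not an authoritative reference): block_q4_1 dequantizes as
// x_i = d4*q4_i + m4, and the q8_1 representation of Q stores per block the
// pair (d8, s8), where s8 is the sum of the already scaled Q values. So per
// block:
//
//     sum_i x_i*y_i = d4*d8*sumi + m4*s8
//
// Each of the QI8_1 ints of a q8_1 block passes through this loop (spread
// across the warp and summed by the later warp reduction), so the m4*s8 term
// is scaled by 1.0f/QI8_1 to be counted exactly once per block.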
+ return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ > MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ > MIN_CC_DP4A + + const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_0; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0); + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = __dp4a(v, u, 0); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE]; + sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */; + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y)); + } + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ > MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ > MIN_CC_DP4A + + const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_1; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F; + const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1); + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + + const int u = Q_q8[k_KQ_0/WARP_SIZE]; + + const int sumi = __dp4a(v, u, 0); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + + const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE]; + const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1); + sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled)); + } else +#endif // FP16_AVAILABLE + { + const float2 * Q_ds = (const float2 *) Q_ds_v; + + const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi; + const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1; + + sum += (T) (sumid5d8 + m5s8scaled); + } + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + 
GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ > MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { +#if __CUDA_ARCH__ > MIN_CC_DP4A + + const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; + GGML_UNUSED(Q_v); + + T sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const int ib = k_KQ / QI8_0; + const int iqs = k_KQ % QI8_0; + + const int v = get_int_from_int8(K_q8_0[ib].qs, iqs); + + T Q_d; + if (std::is_same::value) { + const half2 * Q_ds = (const half2 *) Q_ds_v; + Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]); + } else { + const float2 * Q_ds = (const float2 *) Q_ds_v; + Q_d = Q_ds[k_KQ_0/WARP_SIZE].x; + } + + sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d); + } + + return sum; +#else + GGML_UNUSED(K_c); + GGML_UNUSED(Q_v); + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ > MIN_CC_DP4A +} + +template +static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { + + const half2 * K_h2 = (const half2 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + +#if FP16_AVAILABLE + if (std::is_same::value) { + const half2 * Q_h2 = (const half2 *) Q_v; + + half2 sum2 = make_half2(0.0f, 0.0f); + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const half2 K_ik = K_h2[k_KQ]; + sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE]; + } + + return __low2half(sum2) + __high2half(sum2); + } +#endif // FP16_AVAILABLE + + const float2 * Q_f2 = (const float2 *) Q_v; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { + const int k_KQ = k_KQ_0 + threadIdx.x; + + const half2 K_ik = K_h2[k_KQ]; + sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x; + sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y; + } + + return sum; +} + +template +static __device__ __forceinline__ void quantize_q8_1_to_shared( + const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) { + + float vals[sizeof(int)] = {0.0f}; +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + vals[l] = scale * x[4*threadIdx.x + l]; + } + + float amax = fabsf(vals[0]); + float sum = vals[0]; +#pragma unroll + for (int l = 1; l < sizeof(int); ++l) { + amax = fmaxf(amax, fabsf(vals[l])); + sum += vals[l]; + } +#pragma unroll + for (int mask = QI8_1/2; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32)); + sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32); + } + + const float d = amax / 127; + int q32 = 0; + int8_t * q8 = (int8_t *) &q32; + + if (d != 0.0f) { +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + q8[l] = roundf(vals[l] / d); + } + } + + yq32[threadIdx.x] = q32; + if (threadIdx.x % QI8_1 == 0) { + if (std::is_same::value) { + ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum); + } else { + ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum); + } + } +} + +typedef half (*dequantize_1_f16_t)(const void *, const int64_t); +typedef float (*dequantize_1_f32_t)(const void *, const int64_t); + +template +static __device__ __forceinline__ T dequantize_1_q4_0(const void * 
__restrict__ vx, const int64_t i) { + const block_q4_0 * x = (const block_q4_0 *) vx; + + const int64_t ib = i / QK4_0; + const int iqs = i % (QK4_0/2); + const int shift = (i % QK4_0) / (QK4_0/2); + + const T d = x[ib].d; + const int q0 = x[ib].qs[iqs]; + const int q = ((q0 >> (4*shift)) & 0x0F) - 8; + + return d*((T) q); +} + +template +static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) { + const block_q4_1 * x = (const block_q4_1 *) vx; + + const int64_t ib = i / QK4_1; + const int iqs = i % (QK4_1/2); + const int shift = (i % QK4_1) / (QK4_1/2); + + const half2 dm = x[ib].dm; + const int q0 = x[ib].qs[iqs]; + const int q = ((q0 >> (4*shift)) & 0x0F); + +#if FP16_AVAILABLE + if (std::is_same::value) { + return __low2half(dm)*((half) q) + __high2half(dm); + } +#endif // FP16_AVAILABLE + + return __low2float(dm)*((float) q) + __high2float(dm); +} + +template +static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const int64_t ib = i / QK5_0; + const int idq = i % QK5_0; + const int iqs = i % (QK5_0/2); + const int shift = (i % QK5_0) / (QK5_0/2); + + const T d = x[ib].d; + const int ql0 = x[ib].qs[iqs]; + const int qh0 = get_int_from_uint8(x[ib].qh, 0); + const int ql = ((ql0 >> (4*shift)) & 0x0F); + const int qh = ((qh0 >> idq) << 4) & 0x10; + const int q = (ql | qh) - 16; + + return d*((T) q); +} + +template +static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) { + const block_q5_1 * x = (const block_q5_1 *) vx; + + const int64_t ib = i / QK5_1; + const int idq = i % QK5_1; + const int iqs = i % (QK5_1/2); + const int shift = (i % QK5_1) / (QK5_1/2); + + const half2 dm = x[ib].dm; + const int ql0 = x[ib].qs[iqs]; + const int qh0 = get_int_from_uint8_aligned(x[ib].qh, 0); + const int ql = ((ql0 >> (4*shift)) & 0x0F); + const int qh = ((qh0 >> idq) << 4) & 0x10; + const int q = (ql | qh); + +#if FP16_AVAILABLE + if (std::is_same::value) { + return __low2half(dm)*((half) q) + __high2half(dm); + } +#endif // FP16_AVAILABLE + + return __low2float(dm)*((float) q) + __high2float(dm); +} + +template +static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) { + const block_q8_0 * x = (const block_q8_0 *) vx; + + const int64_t ib = i / QK8_0; + const int iqs = i % QK8_0; + + const T d = x[ib].d; + const int q = x[ib].qs[iqs]; + + return d*((T) q); +} + +template +static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) { + const half * x = (const half *) vx; + + return x[i]; +} + template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) @@ -83,6 +536,45 @@ static __global__ void flash_attn_combine_results( dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator; } +// Aliases for FATTN_VEC_CASE macro: +static constexpr ggml_type ggml_type_q4_0 = GGML_TYPE_Q4_0; +static constexpr ggml_type ggml_type_q4_1 = GGML_TYPE_Q4_1; +static constexpr ggml_type ggml_type_q5_0 = GGML_TYPE_Q5_0; +static constexpr ggml_type ggml_type_q5_1 = GGML_TYPE_Q5_1; +static constexpr ggml_type ggml_type_q8_0 = GGML_TYPE_Q8_0; +static constexpr ggml_type ggml_type_f16 = GGML_TYPE_F16; + +typedef half f16; +typedef float f32; + +#define FATTN_VEC_CASE(type_VKQ, D, type_suffix_K, type_suffix_V) \ + if (Q->ne[0] == (D) && K->type == ggml_type_##type_suffix_K && V->type == 
ggml_type_##type_suffix_V) { \ + constexpr int nwarps = (D)/WARP_SIZE; \ + constexpr bool Q_q8_1 = ggml_type_##type_suffix_K != GGML_TYPE_F16; \ + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_##type_VKQ< \ + (D), cols_per_block, parallel_blocks, \ + vec_dot_fattn_vec_KQ_##type_suffix_K, Q_q8_1, dequantize_1_##type_suffix_V>; \ + launch_fattn<(D), parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block); \ + return; \ + } \ + +static void on_no_fattn_vec_case(const int D) { + if (D == 64) { + fprintf(stderr, "Unsupported KV type combination for head_size 64.\n"); + fprintf(stderr, "By default only f16 KV cache is supported.\n"); + fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for V cache quantization support.\n"); + GGML_ASSERT(false); + } else { + fprintf(stderr, "Unsupported KV type combination for head_size 128.\n"); + fprintf(stderr, "Supported combinations:\n"); + fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n"); + fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n"); + fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n"); + fprintf(stderr, "Compile with LLAMA_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); + GGML_ASSERT(false); + } +} + template void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) { const ggml_tensor * Q = dst->src[0]; @@ -94,8 +586,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern ggml_tensor * KQV = dst; GGML_ASSERT(Q->type == GGML_TYPE_F32); - GGML_ASSERT(K->type == GGML_TYPE_F16); - GGML_ASSERT(V->type == GGML_TYPE_F16); GGML_ASSERT(KQV->type == GGML_TYPE_F32); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); @@ -143,6 +633,7 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern mask ? mask->ne[1] : 0, mask ? 
mask->nb[1] : 0, Q->nb[1], Q->nb[2], Q->nb[3], K->nb[1], K->nb[2], K->nb[3], + V->nb[1], V->nb[2], V->nb[3], KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] ); CUDA_CHECK(cudaGetLastError()); diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu index cdb5eaff7..3d64a9eba 100644 --- a/ggml-cuda/fattn-tile-f16.cu +++ b/ggml-cuda/fattn-tile-f16.cu @@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f16( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, diff --git a/ggml-cuda/fattn-tile-f32.cu b/ggml-cuda/fattn-tile-f32.cu index 5a3de2918..61fce0a7e 100644 --- a/ggml-cuda/fattn-tile-f32.cu +++ b/ggml-cuda/fattn-tile-f32.cu @@ -36,6 +36,9 @@ static __global__ void flash_attn_tile_ext_f32( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu index 808e8f362..fa6750364 100644 --- a/ggml-cuda/fattn-vec-f16.cu +++ b/ggml-cuda/fattn-vec-f16.cu @@ -2,7 +2,7 @@ #include "fattn-common.cuh" #include "fattn-vec-f16.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -34,6 +34,9 @@ static __global__ void flash_attn_vec_ext_f16( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, @@ -45,13 +48,11 @@ static __global__ void flash_attn_vec_ext_f16( const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; + Q += nb02* blockIdx.y + nb01*ic0; + K += nb12*(blockIdx.y / gqa_ratio); + V += nb22*(blockIdx.y / gqa_ratio); - const int stride_KV = nb11 / sizeof(half); - const int stride_KV2 = nb11 / sizeof(half2); + const half * maskh = (const half *) mask + ne11*ic0; const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); const half slopeh = __float2half(slopef); @@ -62,10 +63,6 @@ static __global__ void flash_attn_vec_ext_f16( __builtin_assume(tid < D); __shared__ half KQ[ncols*D]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - KQ[j*D + tid] = -HALF_MAX_HALF; - } half2 * KQ2 = (half2 *) KQ; half kqmax[ncols]; @@ -86,17 +83,76 @@ static __global__ void flash_attn_vec_ext_f16( } __syncthreads(); - // Convert Q to half2 and store in registers: - half2 Q_h2[ncols][D/(2*WARP_SIZE)]; + // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: + half2 Q_h2[ncols][D/(2*WARP_SIZE)]; + int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; + half2 Q_ds[ncols][D/QK8_1 == 0 ? 
1 : D/QK8_1]; + if (Q_q8_1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 2 && ic0 + j >= ne01) { +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tmp_q_i32[i] = 0; + } + if (threadIdx.x < D/QK8_1) { + tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f); + } + continue; + } + + const float * Q_f = (const float *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; + Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; + } + } + + __syncthreads(); + } else { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); + Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); + } + } + } + + #pragma unroll for (int j = 0; j < ncols; ++j) { -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); - Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); - } + KQ[j*D + tid] = -HALF_MAX_HALF; } half2 VKQ[ncols] = {{0.0f, 0.0f}}; @@ -123,22 +179,10 @@ static __global__ void flash_attn_vec_ext_f16( break; } - half2 sum2[ncols] = {{0.0f, 0.0f}}; -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE]; - } - } - #pragma unroll for (int j = 0; j < ncols; ++j) { - sum2[j] = warp_reduce_sum(sum2[j]); - half sum = __low2half(sum2[j]) + __high2half(sum2[j]); + half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum(sum); sum += mask ? 
slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); if (ncols == 1) { @@ -189,8 +233,8 @@ static __global__ void flash_attn_vec_ext_f16( } half2 V_k; - reinterpret_cast(V_k.x) = V_h[(k_VKQ_0 + k0 + 0)*stride_KV + tid]; - reinterpret_cast(V_k.y) = V_h[(k_VKQ_0 + k0 + 1)*stride_KV + tid]; + reinterpret_cast(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid); + reinterpret_cast(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; @@ -248,19 +292,22 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens case 64: { constexpr int D = 64; constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16< + D, cols_per_block, parallel_blocks, vec_dot_fattn_vec_KQ_f16, false, dequantize_1_f16>; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; case 128: { constexpr int D = 128; constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16< + D, cols_per_block, parallel_blocks, vec_dot_fattn_vec_KQ_f16, false, dequantize_1_f16>; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; case 256: { constexpr int D = 256; constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; + fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16< + D, cols_per_block, parallel_blocks, vec_dot_fattn_vec_KQ_f16, false, dequantize_1_f16>; launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); } break; default: @@ -272,57 +319,100 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens template void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); - } break; - default: { - GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + +#ifdef GGML_CUDA_FA_ALL_QUANTS + FATTN_VEC_CASE(f16, 64, f16, q4_0) + FATTN_VEC_CASE(f16, 64, f16, q4_1) + FATTN_VEC_CASE(f16, 64, f16, q5_0) + FATTN_VEC_CASE(f16, 64, f16, q5_1) + FATTN_VEC_CASE(f16, 64, f16, q8_0) + FATTN_VEC_CASE(f16, 64, f16, f16) + + FATTN_VEC_CASE(f16, 128, q4_0, q4_0) + FATTN_VEC_CASE(f16, 128, q4_0, q4_1) + FATTN_VEC_CASE(f16, 128, q4_0, q5_0) + FATTN_VEC_CASE(f16, 128, q4_0, q5_1) + FATTN_VEC_CASE(f16, 128, q4_0, q8_0) + FATTN_VEC_CASE(f16, 128, q4_0, f16) + + FATTN_VEC_CASE(f16, 128, q4_1, q4_0) + FATTN_VEC_CASE(f16, 128, q4_1, q4_1) + FATTN_VEC_CASE(f16, 128, q4_1, q5_0) + FATTN_VEC_CASE(f16, 128, q4_1, q5_1) + FATTN_VEC_CASE(f16, 128, q4_1, q8_0) + FATTN_VEC_CASE(f16, 128, q4_1, f16) + + FATTN_VEC_CASE(f16, 128, q5_0, q4_0) + FATTN_VEC_CASE(f16, 128, q5_0, q4_1) + FATTN_VEC_CASE(f16, 128, q5_0, q5_0) + FATTN_VEC_CASE(f16, 128, q5_0, q5_1) + FATTN_VEC_CASE(f16, 128, q5_0, q8_0) + FATTN_VEC_CASE(f16, 128, q5_0, f16) + + FATTN_VEC_CASE(f16, 128, q5_1, 
q4_0) + FATTN_VEC_CASE(f16, 128, q5_1, q4_1) + FATTN_VEC_CASE(f16, 128, q5_1, q5_0) + FATTN_VEC_CASE(f16, 128, q5_1, q5_1) + FATTN_VEC_CASE(f16, 128, q5_1, q8_0) + FATTN_VEC_CASE(f16, 128, q5_1, f16) + + FATTN_VEC_CASE(f16, 128, q8_0, q4_0) + FATTN_VEC_CASE(f16, 128, q8_0, q4_1) + FATTN_VEC_CASE(f16, 128, q8_0, q5_0) + FATTN_VEC_CASE(f16, 128, q8_0, q5_1) + FATTN_VEC_CASE(f16, 128, q8_0, q8_0) + FATTN_VEC_CASE(f16, 128, q8_0, f16) + + FATTN_VEC_CASE(f16, 128, f16, q4_0) + FATTN_VEC_CASE(f16, 128, f16, q4_1) + FATTN_VEC_CASE(f16, 128, f16, q5_0) + FATTN_VEC_CASE(f16, 128, f16, q5_1) + FATTN_VEC_CASE(f16, 128, f16, q8_0) + FATTN_VEC_CASE(f16, 128, f16, f16) +#else + FATTN_VEC_CASE(f16, 128, q4_0, q4_0) + FATTN_VEC_CASE(f16, 128, q8_0, q8_0) + FATTN_VEC_CASE(f16, 128, f16, f16) +#endif // GGML_CUDA_FA_ALL_QUANTS + + on_no_fattn_vec_case(Q->ne[0]); } void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; const int32_t precision = KQV->op_params[2]; GGML_ASSERT(precision == GGML_PREC_DEFAULT); - if (Q->ne[1] == 1) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - return; - } + // if (Q->ne[1] == 1) { + // constexpr int cols_per_block = 1; + // constexpr int parallel_blocks = 4; + // launch_fattn_vec_f16_64_128(ctx, dst); + // return; + // } - if (Q->ne[1] == 2) { - constexpr int cols_per_block = 2; - constexpr int parallel_blocks = 4; - launch_fattn_vec_f16_64_128(ctx, dst); - return; - } + // if (Q->ne[1] == 2) { + // constexpr int cols_per_block = 2; + // constexpr int parallel_blocks = 4; + // launch_fattn_vec_f16_64_128(ctx, dst); + // return; + // } - if (Q->ne[1] <= 4) { - constexpr int cols_per_block = 4; - constexpr int parallel_blocks = 4; - launch_fattn_vec_f16_64_128(ctx, dst); - return; - } + // if (Q->ne[1] <= 4) { + // constexpr int cols_per_block = 4; + // constexpr int parallel_blocks = 4; + // launch_fattn_vec_f16_64_128(ctx, dst); + // return; + // } - if (Q->ne[1] <= 8) { - constexpr int cols_per_block = 8; - constexpr int parallel_blocks = 4; - launch_fattn_vec_f16_64_128(ctx, dst); - return; - } + // if (Q->ne[1] <= 8) { + // constexpr int cols_per_block = 8; + // constexpr int parallel_blocks = 4; + // launch_fattn_vec_f16_64_128(ctx, dst); + // return; + // } constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; diff --git a/ggml-cuda/fattn-vec-f32.cu b/ggml-cuda/fattn-vec-f32.cu index b4652301b..dded24320 100644 --- a/ggml-cuda/fattn-vec-f32.cu +++ b/ggml-cuda/fattn-vec-f32.cu @@ -2,7 +2,7 @@ #include "fattn-common.cuh" #include "fattn-vec-f32.cuh" -template // D == head size +template // D == head size #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) __launch_bounds__(D, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) @@ -34,6 +34,9 @@ static __global__ void flash_attn_vec_ext_f32( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, @@ -44,13 +47,10 @@ static __global__ void flash_attn_vec_ext_f32( const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. 
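// Illustration of the indexing above (editorial sketch, not kernel code):
// gqa_ratio = ne02/ne12, so Q head h reads K/V head h/gqa_ratio; e.g. for
// ne02 == 32 and ne12 == 8, gqa_ratio == 4 and Q heads 0..3 share K/V head 0.
// The rewrite below also drops the typed pointers (half2 *, half *): with
// quantized caches the element size varies by type, so K/V rows are addressed
// by the byte strides nb1x/nb2x and decoded by the per-type vec_dot/dequantize
// callbacks, e.g.
//
//     const char * K_row = K + (int64_t)(k_VKQ_0 + i_KQ)*nb11; // any K type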
- const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio)); - const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) mask + ne11*ic0; - - const int stride_KV = nb11 / sizeof(half); - const int stride_KV2 = nb11 / sizeof(half2); + Q += nb02* blockIdx.y + nb01*ic0; + K += nb12*(blockIdx.y / gqa_ratio); + V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape + const half * maskh = (const half *) mask + ne11*ic0; const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); @@ -83,17 +83,69 @@ static __global__ void flash_attn_vec_ext_f32( } __syncthreads(); - // Convert Q to half2 and store in registers: - float2 Q_h2[ncols][D/(2*WARP_SIZE)]; + // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: + float2 Q_f2[ncols][D/(2*WARP_SIZE)]; + int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; + float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + if (Q_q8_1) { #pragma unroll - for (int j = 0; j < ncols; ++j) { -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; - Q_h2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); - Q_h2[j][i0/WARP_SIZE].x *= scale; - Q_h2[j][i0/WARP_SIZE].y *= scale; + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 2 && ic0 + j >= ne01) { +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tmp_q_i32[i] = 0; + } + if (threadIdx.x < D/QK8_1) { + tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); + } + continue; + } + + const float * Q_f = (const float *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; + Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; + } + } + + __syncthreads(); + } else { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? 
Q_f2_j[i] : make_float2(0.0f, 0.0f); + Q_f2[j][i0/WARP_SIZE].x *= scale; + Q_f2[j][i0/WARP_SIZE].y *= scale; + } } } @@ -117,28 +169,16 @@ static __global__ void flash_attn_vec_ext_f32( break; } - float sum[ncols] = {0.0f}; -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; -#pragma unroll - for (int j = 0; j < ncols; ++j) { - sum[j] += __low2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].x; - sum[j] += __high2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].y; - } - } - #pragma unroll for (int j = 0; j < ncols; ++j) { - sum[j] = warp_reduce_sum(sum[j]); - sum[j] += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; + float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum(sum); + sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum[j]); + kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); if (threadIdx.x == 0) { - KQ[j*D + i_KQ] = sum[j]; + KQ[j*D + i_KQ] = sum; } } } @@ -178,7 +218,7 @@ static __global__ void flash_attn_vec_ext_f32( break; } - const float V_ki = __half2float(V_h[(k_VKQ_0 + k)*stride_KV + tid]); + const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_ki*KQ[j*D + k]; @@ -223,23 +263,65 @@ static __global__ void flash_attn_vec_ext_f32( template void launch_fattn_vec_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = D/WARP_SIZE; - fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32; - launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block); - } break; - default: { - GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + +#ifdef GGML_CUDA_FA_ALL_QUANTS + FATTN_VEC_CASE(f32, 64, f16, q4_0) + FATTN_VEC_CASE(f32, 64, f16, q4_1) + FATTN_VEC_CASE(f32, 64, f16, q5_0) + FATTN_VEC_CASE(f32, 64, f16, q5_1) + FATTN_VEC_CASE(f32, 64, f16, q8_0) + FATTN_VEC_CASE(f32, 64, f16, f16) + + FATTN_VEC_CASE(f32, 128, q4_0, q4_0) + FATTN_VEC_CASE(f32, 128, q4_0, q4_1) + FATTN_VEC_CASE(f32, 128, q4_0, q5_0) + FATTN_VEC_CASE(f32, 128, q4_0, q5_1) + FATTN_VEC_CASE(f32, 128, q4_0, q8_0) + FATTN_VEC_CASE(f32, 128, q4_0, f16) + + FATTN_VEC_CASE(f32, 128, q4_1, q4_0) + FATTN_VEC_CASE(f32, 128, q4_1, q4_1) + FATTN_VEC_CASE(f32, 128, q4_1, q5_0) + FATTN_VEC_CASE(f32, 128, q4_1, q5_1) + FATTN_VEC_CASE(f32, 128, q4_1, q8_0) + FATTN_VEC_CASE(f32, 128, q4_1, f16) + + FATTN_VEC_CASE(f32, 128, q5_0, q4_0) + FATTN_VEC_CASE(f32, 128, q5_0, q4_1) + FATTN_VEC_CASE(f32, 128, q5_0, q5_0) + FATTN_VEC_CASE(f32, 128, q5_0, q5_1) + FATTN_VEC_CASE(f32, 128, q5_0, q8_0) + FATTN_VEC_CASE(f32, 128, q5_0, f16) + + FATTN_VEC_CASE(f32, 128, q5_1, q4_0) + FATTN_VEC_CASE(f32, 128, q5_1, q4_1) + FATTN_VEC_CASE(f32, 128, q5_1, q5_0) + FATTN_VEC_CASE(f32, 128, q5_1, q5_1) + FATTN_VEC_CASE(f32, 128, q5_1, q8_0) + FATTN_VEC_CASE(f32, 128, q5_1, f16) + + FATTN_VEC_CASE(f32, 128, q8_0, q4_0) + FATTN_VEC_CASE(f32, 128, q8_0, q4_1) + 
FATTN_VEC_CASE(f32, 128, q8_0, q5_0) + FATTN_VEC_CASE(f32, 128, q8_0, q5_1) + FATTN_VEC_CASE(f32, 128, q8_0, q8_0) + FATTN_VEC_CASE(f32, 128, q8_0, f16) + + FATTN_VEC_CASE(f32, 128, f16, q4_0) + FATTN_VEC_CASE(f32, 128, f16, q4_1) + FATTN_VEC_CASE(f32, 128, f16, q5_0) + FATTN_VEC_CASE(f32, 128, f16, q5_1) + FATTN_VEC_CASE(f32, 128, f16, q8_0) + FATTN_VEC_CASE(f32, 128, f16, f16) +#else + FATTN_VEC_CASE(f32, 128, q4_0, q4_0) + FATTN_VEC_CASE(f32, 128, q8_0, q8_0) + FATTN_VEC_CASE(f32, 128, f16, f16) +#endif // GGML_CUDA_FA_ALL_QUANTS + + on_no_fattn_vec_case(Q->ne[0]); } void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index af7c95232..81b50eb58 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -45,6 +45,9 @@ static __global__ void flash_attn_ext_f16( const int nb11, const int nb12, const int nb13, + const int nb21, + const int nb22, + const int nb23, const int ne0, const int ne1, const int ne2, @@ -457,14 +460,18 @@ void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; ggml_cuda_set_device(ctx.device); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int32_t precision = KQV->op_params[2]; + const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type); + // On AMD the tile kernels perform poorly, use the vec kernel instead: - if (cc >= CC_OFFSET_AMD) { - if (precision == GGML_PREC_DEFAULT) { + if (cc >= CC_OFFSET_AMD || quantized_KV) { + if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst); } else { ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu index c0a66d9b6..ebe1dc5c8 100644 --- a/ggml-cuda/mmq.cu +++ b/ggml-cuda/mmq.cu @@ -386,7 +386,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; } - return vec_dot_q8_0_q8_1_impl + return vec_dot_q8_0_q8_1_impl (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); } @@ -547,7 +547,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( const float * x_dmf = (const float *) x_dm; const float * y_df = (const float *) y_ds; - return vec_dot_q8_0_q8_1_impl + return vec_dot_q8_0_q8_1_impl (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); } diff --git a/ggml-cuda/vecdotq.cuh b/ggml-cuda/vecdotq.cuh index 5ebdddcc7..df9752390 100644 --- a/ggml-cuda/vecdotq.cuh +++ b/ggml-cuda/vecdotq.cuh @@ -180,8 +180,8 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp #define VDR_Q8_0_Q8_1_MMVQ 2 #define VDR_Q8_0_Q8_1_MMQ 8 -template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( - const int * v, const int * u, const float & d8_0, const float & d8_1) { +template static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl( + const int * v, const int * u, const T & d8_0, const T & d8_1) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics int sumi = 0; @@ -192,7 +192,7 @@ template static __device__ __forceinline__ float 
vec_dot_q8_0_q8_1_imp sumi = __dp4a(v[i], u[i], sumi); } - return d8_0*d8_1 * sumi; + return d8_0*d8_1 * ((T) sumi); #else NO_DEVICE_CODE; #endif // __CUDA_ARCH__ >= MIN_CC_DP4A @@ -656,7 +656,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); } - return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); + return vec_dot_q8_0_q8_1_impl(v, u, bq8_0->d, __low2half(bq8_1->ds)); } static __device__ __forceinline__ float vec_dot_q2_K_q8_1( From 462add6a01d77526a35899cfd3f6ffa02575e5fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 25 May 2024 22:06:25 +0200 Subject: [PATCH 2/6] try CI fix --- ggml-cuda/fattn-common.cuh | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh index 89bbc9749..b9703dac2 100644 --- a/ggml-cuda/fattn-common.cuh +++ b/ggml-cuda/fattn-common.cuh @@ -404,7 +404,13 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ const int q0 = x[ib].qs[iqs]; const int q = ((q0 >> (4*shift)) & 0x0F) - 8; - return d*((T) q); +#if FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); } template @@ -444,7 +450,13 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ const int qh = ((qh0 >> idq) << 4) & 0x10; const int q = (ql | qh) - 16; - return d*((T) q); +#if FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); } template @@ -482,7 +494,13 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ const T d = x[ib].d; const int q = x[ib].qs[iqs]; - return d*((T) q); +#if FP16_AVAILABLE + if (std::is_same::value) { + return ((half) d)*((half) q); + } +#endif // FP16_AVAILABLE + + return ((float) d)*((float) q); } template From 3194a0105868757412ed010b7878727456e6b920 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 26 May 2024 20:14:55 +0200 Subject: [PATCH 3/6] fix commented-out kernel variants --- ggml-cuda/fattn-vec-f16.cu | 49 +++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu index fa6750364..71581509c 100644 --- a/ggml-cuda/fattn-vec-f16.cu +++ b/ggml-cuda/fattn-vec-f16.cu @@ -382,37 +382,38 @@ void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; const int32_t precision = KQV->op_params[2]; GGML_ASSERT(precision == GGML_PREC_DEFAULT); - // if (Q->ne[1] == 1) { - // constexpr int cols_per_block = 1; - // constexpr int parallel_blocks = 4; - // launch_fattn_vec_f16_64_128(ctx, dst); - // return; - // } + if (Q->ne[1] == 1) { + constexpr int cols_per_block = 1; + constexpr int parallel_blocks = 4; + launch_fattn_vec_f16_64_128(ctx, dst); + return; + } - // if (Q->ne[1] == 2) { - // constexpr int cols_per_block = 2; - // constexpr int parallel_blocks = 4; - // launch_fattn_vec_f16_64_128(ctx, dst); - // return; - // } + if (Q->ne[1] == 2) { + constexpr int cols_per_block = 2; + constexpr int parallel_blocks = 4; + launch_fattn_vec_f16_64_128(ctx, dst); + return; + } - // if (Q->ne[1] <= 4) { - 
// constexpr int cols_per_block = 4; - // constexpr int parallel_blocks = 4; - // launch_fattn_vec_f16_64_128(ctx, dst); - // return; - // } + if (Q->ne[1] <= 4) { + constexpr int cols_per_block = 4; + constexpr int parallel_blocks = 4; + launch_fattn_vec_f16_64_128(ctx, dst); + return; + } - // if (Q->ne[1] <= 8) { - // constexpr int cols_per_block = 8; - // constexpr int parallel_blocks = 4; - // launch_fattn_vec_f16_64_128(ctx, dst); - // return; - // } + if (Q->ne[1] <= 8) { + constexpr int cols_per_block = 8; + constexpr int parallel_blocks = 4; + launch_fattn_vec_f16_64_128(ctx, dst); + return; + } constexpr int cols_per_block = 8; constexpr int parallel_blocks = 1; From f08776041dd97841e1c3f446ca8dd523a44dc791 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 26 May 2024 22:30:46 +0200 Subject: [PATCH 4/6] add q8_0 q4_0 tests --- ggml-cuda.cu | 8 ++++++-- tests/test-backend-ops.cpp | 16 ++++++++++------ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2a90ee55c..a6e9a991b 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2889,10 +2889,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128; #else - if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) { + if (op->src[0]->ne[0] == 128) { return true; } - return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA; + if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { + return true; + } + return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA && + op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) default: return false; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index de74585da..8305e4743 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1540,21 +1540,23 @@ struct test_flash_attn_ext : public test_case { const float max_bias; // ALiBi + const ggml_type type_KV; + std::string vars() override { - return VARS_TO_STR6(hs, nh, kv, nb, mask, max_bias); + return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV); } double max_nmse_err() override { return 5e-4; } - test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f) - : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias) {} + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16) + : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); - ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); - ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs, kv, nh, 1); ggml_tensor * m = mask ? 
ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr; ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias); return out; @@ -2238,7 +2240,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (int nh : { 32, }) { for (int kv : { 512, 1024, }) { for (int nb : { 1, 2, 4, 8, }) { - test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias)); + for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV)); + } } } } From f4003cfba1d95511ae59820ff5ea1b62e3cce225 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 26 May 2024 23:00:15 +0200 Subject: [PATCH 5/6] fix nwarps > batch size --- ggml-cuda/fattn-vec-f16.cu | 4 ++++ ggml-cuda/fattn-vec-f32.cu | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu index 71581509c..c427e18ab 100644 --- a/ggml-cuda/fattn-vec-f16.cu +++ b/ggml-cuda/fattn-vec-f16.cu @@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f16( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + // Reuse KQ as temporary storage for converting Q to q8_1: int * tmp_q_i32 = (int *) &KQ[j*D]; half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); diff --git a/ggml-cuda/fattn-vec-f32.cu b/ggml-cuda/fattn-vec-f32.cu index dded24320..1b6197d05 100644 --- a/ggml-cuda/fattn-vec-f32.cu +++ b/ggml-cuda/fattn-vec-f32.cu @@ -92,6 +92,10 @@ static __global__ void flash_attn_vec_ext_f32( for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + // Reuse KQ as temporary storage for converting Q to q8_1: int * tmp_q_i32 = (int *) &KQ[j*D]; float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); From 1ca802a3e0314b977b4e53ad0dd729fc2f9f487c Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 28 May 2024 01:19:36 +0200 Subject: [PATCH 6/6] parallelize fattn compilation test --- Makefile | 2 +- ggml-cuda/fattn-common.cuh | 378 ++++++++++++++++++++++++++- ggml-cuda/fattn-vec-f16-f16.cu | 7 + ggml-cuda/fattn-vec-f16-q4_0-q4_0.cu | 5 + ggml-cuda/fattn-vec-f16-q8_0-q8_0.cu | 5 + ggml-cuda/fattn-vec-f16.cu | 281 -------------------- 6 files changed, 390 insertions(+), 288 deletions(-) create mode 100644 ggml-cuda/fattn-vec-f16-f16.cu create mode 100644 ggml-cuda/fattn-vec-f16-q4_0-q4_0.cu create mode 100644 ggml-cuda/fattn-vec-f16-q8_0-q8_0.cu diff --git a/Makefile b/Makefile index 94cfdab4e..278a06d2f 100644 --- a/Makefile +++ b/Makefile @@ -508,7 +508,7 @@ define NVCC_COMPILE endef # NVCC_COMPILE endif # JETSON_EOL_MODULE_DETECT -ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh +ggml-cuda/%.o: ggml-cuda/%.cu ggml.h ggml-common.h ggml-cuda/common.cuh $(NVCC_COMPILE) ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh) diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh index b9703dac2..ab840be4d 100644 --- a/ggml-cuda/fattn-common.cuh +++ b/ggml-cuda/fattn-common.cuh @@ -49,7 +49,7 @@ typedef float (*vec_dot_KQ_f32_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); template -static __device__ __forceinline__ T 
vec_dot_fattn_vec_KQ_q4_0( +__device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { #if __CUDA_ARCH__ > MIN_CC_DP4A @@ -263,7 +263,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1( } template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( +__device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { #if __CUDA_ARCH__ > MIN_CC_DP4A @@ -304,7 +304,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0( } template -static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( +__device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { const half2 * K_h2 = (const half2 *) K_c; @@ -393,7 +393,7 @@ typedef half (*dequantize_1_f16_t)(const void *, const int64_t); typedef float (*dequantize_1_f32_t)(const void *, const int64_t); template -static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) { +__device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) { const block_q4_0 * x = (const block_q4_0 *) vx; const int64_t ib = i / QK4_0; @@ -485,7 +485,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ } template -static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) { +__device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) { const block_q8_0 * x = (const block_q8_0 *) vx; const int64_t ib = i / QK8_0; @@ -504,7 +504,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ } template -static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) { +__device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) { const half * x = (const half *) vx; return x[i]; @@ -669,3 +669,369 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data); CUDA_CHECK(cudaGetLastError()); } + +template // D == head size +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(D, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__global__ void flash_attn_vec_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const int ne00, + const int ne01, + const int ne02, + const int ne03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const int ne31, + const int nb31, + const int nb01, + const int nb02, + const int nb03, + const int nb11, + const int nb12, + const int nb13, + const int nb21, + const int nb22, + const int nb23, + const int ne0, + const int ne1, + const int ne2, + const int ne3) { +#if FP16_AVAILABLE + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on. 
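// How the split-K pieces above fit together (sketch inferred from this file's
// flash_attn_combine_results, no new behavior): block ip covers the KV slice
// starting at ip*D with stride parallel_blocks*D and writes a partial
// numerator plus (kqmax, kqsum) to dst_meta; the combine kernel then merges
// the parallel_blocks partial results with log-sum-exp rescaling:
//
//     m   = max_ip(kqmax_ip)
//     out = sum_ip(exp(kqmax_ip - m)*num_ip) / sum_ip(exp(kqmax_ip - m)*kqsum_ip)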
+ const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel. + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + Q += nb02* blockIdx.y + nb01*ic0; + K += nb12*(blockIdx.y / gqa_ratio); + V += nb22*(blockIdx.y / gqa_ratio); + + const half * maskh = (const half *) mask + ne11*ic0; + + const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); + const half slopeh = __float2half(slopef); + + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = D / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < D); + + __shared__ half KQ[ncols*D]; + half2 * KQ2 = (half2 *) KQ; + + half kqmax[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + kqmax[j] = -HALF_MAX_HALF; + } + half kqsum[ncols] = {0.0f}; + + __shared__ half kqmax_shared[ncols][WARP_SIZE]; + __shared__ half kqsum_shared[ncols][WARP_SIZE]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.y == 0) { + kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF; + kqsum_shared[j][threadIdx.x] = 0.0f; + } + } + __syncthreads(); + + // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: + half2 Q_h2[ncols][D/(2*WARP_SIZE)]; + int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)]; + half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1]; + if (Q_q8_1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 2 && ic0 + j >= ne01) { +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tmp_q_i32[i] = 0; + } + if (threadIdx.x < D/QK8_1) { + tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f); + } + continue; + } + + const float * Q_f = (const float *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + quantize_q8_1_to_shared(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds); + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + half2 * tmp_q_ds = (half2 *) (tmp_q_i32 + D/sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i]; + Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1]; + } + } + + __syncthreads(); + } else { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_f2_j = (const float2 *) (Q + j*nb01); + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); + Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); + } + } + } + + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ[j*D + tid] = -HALF_MAX_HALF; + } + + half2 VKQ[ncols] = {{0.0f, 0.0f}}; + + const int k_start = parallel_blocks == 1 ? 
+    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+
+        // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
+        // see https://github.com/ggerganov/llama.cpp/pull/7061 .
+        // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
+        half kqmax_new = kqmax[0];
+        half kqmax_new_arr[ncols];
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            kqmax_new_arr[j] = kqmax[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
+            const int i_KQ = i_KQ_0 + threadIdx.y;
+
+            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
+                break;
+            }
+
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+                sum = warp_reduce_sum(sum);
+                sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+
+                if (ncols == 1) {
+                    kqmax_new        = ggml_cuda_hmax(kqmax_new, sum);
+                } else {
+                    kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
+                }
+
+                if (threadIdx.x == 0) {
+                    KQ[j*D + i_KQ] = sum;
+                }
+            }
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
+
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const half val = hexp(KQ[j*D + tid] - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale + val;
+            KQ[j*D + tid] = val;
+
+            VKQ[j] *= __half2half2(KQ_max_scale);
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int k0 = 0; k0 < D; k0 += 2) {
+            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                break;
+            }
+
+            half2 V_k;
+            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
+            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        kqsum[j] = warp_reduce_sum(kqsum[j]);
+        if (threadIdx.x == 0) {
+            kqsum_shared[j][threadIdx.y] = kqsum[j];
+        }
+    }
+
+    __syncthreads();
+
+#pragma unroll
+    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
+            break;
+        }
+
+        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
+        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
+
+        half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
+        if (parallel_blocks == 1) {
+            dst_val /= kqsum[j_VKQ];
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+    }
+
+    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    }
+#else
+    NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+}
+
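+// The DECL_* macros below expand to explicit instantiations of the kernel
+// template above, e.g. DECL_FATTN_VEC_F16_INST(64, 1, 4, ..., false, dequantize_1_f16)
+// becomes:
+//   template __global__ void flash_attn_vec_ext_f16<64, 1, 4, ..., false, dequantize_1_f16>(...);
+// The extern declarations that follow only declare the instantiations; the
+// definitions live in the per-type fattn-vec-f16-*.cu files added below.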
+#define DECL_FATTN_VEC_F16_INST(D, ncols, parallel_blocks, vec_dot_KQ, Q_q8_1, dequantize_1_v) \
+    template __global__ void flash_attn_vec_ext_f16<D, ncols, parallel_blocks, vec_dot_KQ, Q_q8_1, dequantize_1_v>( \
+        const char * __restrict__ Q, \
+        const char * __restrict__ K, \
+        const char * __restrict__ V, \
+        const char * __restrict__ mask, \
+        float      * __restrict__ dst, \
+        float2     * __restrict__ dst_meta, \
+        const float scale, \
+        const float max_bias, \
+        const float m0, \
+        const float m1, \
+        const uint32_t n_head_log2, \
+        const int ne00, \
+        const int ne01, \
+        const int ne02, \
+        const int ne03, \
+        const int ne10, \
+        const int ne11, \
+        const int ne12, \
+        const int ne13, \
+        const int ne31, \
+        const int nb31, \
+        const int nb01, \
+        const int nb02, \
+        const int nb03, \
+        const int nb11, \
+        const int nb12, \
+        const int nb13, \
+        const int nb21, \
+        const int nb22, \
+        const int nb23, \
+        const int ne0, \
+        const int ne1, \
+        const int ne2, \
+        const int ne3)
+
+extern DECL_FATTN_VEC_F16_INST( 64, 1, 4, (vec_dot_fattn_vec_KQ_f16<half,  64>), false, dequantize_1_f16);
+extern DECL_FATTN_VEC_F16_INST(128, 1, 4, (vec_dot_fattn_vec_KQ_f16<half, 128>), false, dequantize_1_f16);
+extern DECL_FATTN_VEC_F16_INST(256, 1, 4, (vec_dot_fattn_vec_KQ_f16<half, 256>), false, dequantize_1_f16);
+
+#define DECL_FATTN_VEC_INST(type_VKQ, D, cols_per_block, parallel_blocks, type_suffix_K, type_suffix_V) \
+    template __global__ void flash_attn_vec_ext_##type_VKQ< \
+        (D), cols_per_block, parallel_blocks, \
+        vec_dot_fattn_vec_KQ_##type_suffix_K<half, (D)>, ggml_type_##type_suffix_K != GGML_TYPE_F16, dequantize_1_##type_suffix_V>( \
+        const char * __restrict__ Q, \
+        const char * __restrict__ K, \
+        const char * __restrict__ V, \
+        const char * __restrict__ mask, \
+        float      * __restrict__ dst, \
+        float2     * __restrict__ dst_meta, \
+        const float scale, \
+        const float max_bias, \
+        const float m0, \
+        const float m1, \
+        const uint32_t n_head_log2, \
+        const int ne00, \
+        const int ne01, \
+        const int ne02, \
+        const int ne03, \
+        const int ne10, \
+        const int ne11, \
+        const int ne12, \
+        const int ne13, \
+        const int ne31, \
+        const int nb31, \
+        const int nb01, \
+        const int nb02, \
+        const int nb03, \
+        const int nb11, \
+        const int nb12, \
+        const int nb13, \
+        const int nb21, \
+        const int nb22, \
+        const int nb23, \
+        const int ne0, \
+        const int ne1, \
+        const int ne2, \
+        const int ne3)
+
+extern DECL_FATTN_VEC_INST(f16, 128, 1, 4, q4_0, q4_0);
+extern DECL_FATTN_VEC_INST(f16, 128, 1, 4, q8_0, q8_0);
diff --git a/ggml-cuda/fattn-vec-f16-f16.cu b/ggml-cuda/fattn-vec-f16-f16.cu
new file mode 100644
index 000000000..557b561fb
--- /dev/null
+++ b/ggml-cuda/fattn-vec-f16-f16.cu
@@ -0,0 +1,7 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_INST( 64, 1, 4, (vec_dot_fattn_vec_KQ_f16<half,  64>), false, dequantize_1_f16);
+DECL_FATTN_VEC_F16_INST(128, 1, 4, (vec_dot_fattn_vec_KQ_f16<half, 128>), false, dequantize_1_f16);
+DECL_FATTN_VEC_F16_INST(256, 1, 4, (vec_dot_fattn_vec_KQ_f16<half, 256>), false, dequantize_1_f16);
diff --git a/ggml-cuda/fattn-vec-f16-q4_0-q4_0.cu b/ggml-cuda/fattn-vec-f16-q4_0-q4_0.cu
new file mode 100644
index 000000000..ea1b8ba2b
--- /dev/null
+++ b/ggml-cuda/fattn-vec-f16-q4_0-q4_0.cu
@@ -0,0 +1,5 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_INST(f16, 128, 1, 4, q4_0, q4_0);
diff --git a/ggml-cuda/fattn-vec-f16-q8_0-q8_0.cu b/ggml-cuda/fattn-vec-f16-q8_0-q8_0.cu
new file mode 100644
index 000000000..773575746
--- /dev/null
+++ b/ggml-cuda/fattn-vec-f16-q8_0-q8_0.cu
@@ -0,0 +1,5 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_INST(f16, 128, 1, 4, q8_0, q8_0);
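+// One K/V type combination per translation unit keeps individual compile
+// times small and lets the build parallelize across the fattn-vec-f16-*.cu files.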
diff --git a/ggml-cuda/fattn-vec-f16.cu b/ggml-cuda/fattn-vec-f16.cu
index c427e18ab..eac498c2d 100644
--- a/ggml-cuda/fattn-vec-f16.cu
+++ b/ggml-cuda/fattn-vec-f16.cu
@@ -2,287 +2,6 @@
 #include "fattn-common.cuh"
 #include "fattn-vec-f16.cuh"
 
-template<int D, int ncols, int parallel_blocks, vec_dot_KQ_f16_t vec_dot_KQ, bool Q_q8_1, dequantize_1_f16_t dequantize_1_v> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
-static __global__ void flash_attn_vec_ext_f16(
-        const char * __restrict__ Q,
-        const char * __restrict__ K,
-        const char * __restrict__ V,
-        const char * __restrict__ mask,
-        float      * __restrict__ dst,
-        float2     * __restrict__ dst_meta,
-        const float scale,
-        const float max_bias,
-        const float m0,
-        const float m1,
-        const uint32_t n_head_log2,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int nb31,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
-#if FP16_AVAILABLE
-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
-
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
-
-    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.y              + nb01*ic0;
-    K += nb12*(blockIdx.y / gqa_ratio);
-    V += nb22*(blockIdx.y / gqa_ratio);
-
-    const half * maskh = (const half *) mask + ne11*ic0;
-
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
-    const half  slopeh = __float2half(slopef);
-
-    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
-    constexpr int nwarps = D / WARP_SIZE;
-    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-    __builtin_assume(tid < D);
-
-    __shared__ half KQ[ncols*D];
-    half2 * KQ2 = (half2 *) KQ;
-
-    half kqmax[ncols];
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        kqmax[j] = -HALF_MAX_HALF;
-    }
-    half kqsum[ncols] = {0.0f};
-
-    __shared__ half kqmax_shared[ncols][WARP_SIZE];
-    __shared__ half kqsum_shared[ncols][WARP_SIZE];
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (threadIdx.y == 0) {
-            kqmax_shared[j][threadIdx.x] = -HALF_MAX_HALF;
-            kqsum_shared[j][threadIdx.x] = 0.0f;
-        }
-    }
-    __syncthreads();
-
-    // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
-    half2 Q_h2[ncols][D/(2*WARP_SIZE)];
-    int   Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D/(sizeof(int)*QK8_1)];
-    half2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
-    if (Q_q8_1) {
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-            if (j0 + nwarps > ncols && j >= ncols) {
-                break;
-            }
-
-            // Reuse KQ as temporary storage for converting Q to q8_1:
-            int   * tmp_q_i32 = (int   *) &KQ[j*D];
-            half2 * tmp_q_ds  = (half2 *) (tmp_q_i32 + D/sizeof(int));
-
-            // Set memory to zero if out of bounds:
-            if (ncols > 2 && ic0 + j >= ne01) {
-#pragma unroll
-                for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
-                    const int i = i0 + threadIdx.x;
-
-                    tmp_q_i32[i] = 0;
-                }
-                if (threadIdx.x < D/QK8_1) {
-                    tmp_q_ds[threadIdx.x] = make_half2(0.0f, 0.0f);
-                }
-                continue;
-            }
-
-            const float * Q_f = (const float *) (Q + j*nb01);
-#pragma unroll
-            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
-                quantize_q8_1_to_shared<half2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
-            }
-        }
-
-        __syncthreads();
-
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            int   * tmp_q_i32 = (int   *) &KQ[j*D];
-            half2 * tmp_q_ds  = (half2 *) (tmp_q_i32 + D/sizeof(int));
-
-#pragma unroll
-            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
-                const int i = i0 + threadIdx.x;
-
-                Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
-                Q_ds[j][i0/WARP_SIZE]  = tmp_q_ds[i/QI8_1];
-            }
-        }
-
-        __syncthreads();
-    } else {
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
-
-#pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                const int i = i0 + threadIdx.x;
-
-                const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
-                Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
-            }
-        }
-    }
-
-
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        KQ[j*D + tid] = -HALF_MAX_HALF;
-    }
-
-    half2 VKQ[ncols] = {{0.0f, 0.0f}};
-
-    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
-        // Calculate KQ tile and keep track of new maximum KQ values:
-
-        // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
-        // see https://github.com/ggerganov/llama.cpp/pull/7061 .
-        // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
-        half kqmax_new = kqmax[0];
-        half kqmax_new_arr[ncols];
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            kqmax_new_arr[j] = kqmax[j];
-        }
-
-#pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
-            const int i_KQ = i_KQ_0 + threadIdx.y;
-
-            if ((i_KQ_0 + nwarps > D && i_KQ >= D) || (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + i_KQ >= ne11)) {
-                break;
-            }
-
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-                half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
-                sum = warp_reduce_sum(sum);
-                sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
-
-                if (ncols == 1) {
-                    kqmax_new        = ggml_cuda_hmax(kqmax_new, sum);
-                } else {
-                    kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
-                }
-
-                if (threadIdx.x == 0) {
-                    KQ[j*D + i_KQ] = sum;
-                }
-            }
-        }
-
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
-
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
-            if (threadIdx.x == 0) {
-                kqmax_shared[j][threadIdx.y] = kqmax_new_j;
-            }
-        }
-
-        __syncthreads();
-
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
-            kqmax_new_j = warp_reduce_max(kqmax_new_j);
-
-            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
-            kqmax[j] = kqmax_new_j;
-
-            const half val = hexp(KQ[j*D + tid] - kqmax[j]);
-            kqsum[j] = kqsum[j]*KQ_max_scale + val;
-            KQ[j*D + tid] = val;
-
-            VKQ[j] *= __half2half2(KQ_max_scale);
-        }
-
-        __syncthreads();
-
-#pragma unroll
-        for (int k0 = 0; k0 < D; k0 += 2) {
-            if (FATTN_KQ_STRIDE % D != 0 && k_VKQ_0 + k0 >= ne11) {
-                break;
-            }
-
-            half2 V_k;
-            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
-            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-                VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
-            }
-        }
-
-        __syncthreads();
-    }
-
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        kqsum[j] = warp_reduce_sum(kqsum[j]);
-        if (threadIdx.x == 0) {
-            kqsum_shared[j][threadIdx.y] = kqsum[j];
-        }
-    }
-
-    __syncthreads();
-
-#pragma unroll
-    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
-        if (ncols > 2 && ic0 + j_VKQ >= ne01) {
-            break;
-        }
-
-        kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
-        kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
-
-        half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
-        if (parallel_blocks == 1) {
-            dst_val /= kqsum[j_VKQ];
-        }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
-    }
-
-    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
-    }
-#else
-    NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
-}
-
 void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_tensor * KQV = dst;
     ggml_tensor * Q   = dst->src[0];