diff --git a/ggml-cuda/dequantize.cuh b/ggml-cuda/dequantize.cuh
index bd3c2d9db..4c735e977 100644
--- a/ggml-cuda/dequantize.cuh
+++ b/ggml-cuda/dequantize.cuh
@@ -101,3 +101,11 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
     v.y *= d;
 #endif // GGML_CUDA_F16
 }
+
+static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const half * x = (const half *) vx;
+
+    // automatic half -> float type cast if dfloat == float
+    v.x = x[ib + iqs + 0];
+    v.y = x[ib + iqs + 1];
+}
diff --git a/ggml-cuda/dmmv.cu b/ggml-cuda/dmmv.cu
index 7313e3e17..be02b688d 100644
--- a/ggml-cuda/dmmv.cu
+++ b/ggml-cuda/dmmv.cu
@@ -565,14 +565,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
     }
 }
 
-static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
-    const half * x = (const half *) vx;
-
-    // automatic half -> float type cast if dfloat == float
-    v.x = x[ib + iqs + 0];
-    v.y = x[ib + iqs + 1];
-}
-
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block
diff --git a/ggml-cuda/fattn-common.cuh b/ggml-cuda/fattn-common.cuh
index 1dd519bde..4adbcc6f4 100644
--- a/ggml-cuda/fattn-common.cuh
+++ b/ggml-cuda/fattn-common.cuh
@@ -94,7 +94,6 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
     ggml_tensor * KQV = dst;
 
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(K->type == GGML_TYPE_F16);
     GGML_ASSERT(V->type == GGML_TYPE_F16);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
diff --git a/ggml-cuda/fattn-tile-f16.cu b/ggml-cuda/fattn-tile-f16.cu
index 4a07ac6ad..694ba81ba 100644
--- a/ggml-cuda/fattn-tile-f16.cu
+++ b/ggml-cuda/fattn-tile-f16.cu
@@ -1,10 +1,11 @@
 #include "common.cuh"
+#include "dequantize.cuh"
 #include "fattn-common.cuh"
 #include "fattn-tile-f16.cuh"
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, typename type_k, int qkk, dequantize_kernel_t dequantize_k> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -48,10 +49,11 @@ static __global__ void flash_attn_tile_ext_f16(
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
     const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
+    const type_k * K_h   = (const type_k *) (K    + nb12*(blockIdx.y / gqa_ratio));
     const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
     const half   * maskh = (const half   *)  mask + ne11*ic0;
 
+    const int stride_K   = nb11 / sizeof(type_k);
     const int stride_KV2 = nb11 / sizeof(half2);
 
     const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
@@ -108,7 +110,9 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                 const int k_KQ = k_KQ_0 + threadIdx.x;
 
-                KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+                half2 tmp;
+                dequantize_k(K_h, (k_VKQ_0 + i_KQ)*stride_K + (2*k_KQ)/qkk, (2*k_KQ)%qkk, tmp);
+                KV_tmp[i_KQ][k_KQ] = tmp;
             }
         }
 
@@ -270,13 +274,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
         case 64: {
             constexpr int      D = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, dequantize_k>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         case 128: {
             constexpr int      D = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, type_k, qkk, dequantize_k>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         } break;
         default: {
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index af7c95232..e6159ec69 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -457,11 +457,18 @@ void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * K   = dst->src[1];
+    const ggml_tensor * V   = dst->src[2];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
+    if (ggml_is_quantized(K->type) || ggml_is_quantized(V->type)) {
+        ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+        return;
+    }
+
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
         if (precision == GGML_PREC_DEFAULT) {
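
The diff does not show how launch_fattn_tile_f16_64_128 obtains type_k, qkk, and dequantize_k for a given K cache. Below is a minimal host-side sketch of what that dispatch could look like, assuming the launcher is templated on the same three parameters as the kernel; the helper name launch_fattn_tile_f16_for_K is hypothetical, while block_q8_0, QK8_0, dequantize_q8_0, and convert_f16 are the existing ggml-cuda definitions.

    // Hypothetical dispatch on the K cache type; not part of the diff above.
    // qkk is the number of quantized values per type_k block, so in the kernel
    // (2*k_KQ)/qkk selects the block and (2*k_KQ)%qkk the offset within it.
    template <int cols_per_block, int parallel_blocks>
    static void launch_fattn_tile_f16_for_K(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
        const ggml_tensor * K = dst->src[1];

        switch (K->type) {
            case GGML_TYPE_F16:
                // qkk == 1: the block index degenerates to a plain element offset.
                launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, half, 1, convert_f16>(ctx, dst);
                break;
            case GGML_TYPE_Q8_0:
                // dequantize_q8_0 yields two consecutive values, matching the kernel's 2*k_KQ indexing.
                launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, block_q8_0, QK8_0, dequantize_q8_0>(ctx, dst);
                break;
            default:
                GGML_ASSERT(false && "unsupported K cache type for the f16 tile kernel");
                break;
        }
    }

Note that dequantize kernels for qr == 2 types return a value pair that is not consecutive (dequantize_q4_0, for example, yields elements iqs and iqs + QK4_0/2 of the block), so supporting them would additionally require different index math in the tile kernel.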