CUDA: quantized KV support for FA vec
This commit is contained in:
parent
10b1e45876
commit
672244a88b
11 changed files with 826 additions and 142 deletions
|
@ -2,7 +2,7 @@
|
|||
#include "fattn-common.cuh"
|
||||
#include "fattn-vec-f32.cuh"
|
||||
|
||||
template<int D, int ncols, int parallel_blocks> // D == head size
|
||||
template<int D, int ncols, int parallel_blocks, vec_dot_KQ_f32_t vec_dot_KQ, bool Q_q8_1, dequantize_1_f32_t dequantize_1_v> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
|
@ -34,6 +34,9 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
const int nb11,
|
||||
const int nb12,
|
||||
const int nb13,
|
||||
const int nb21,
|
||||
const int nb22,
|
||||
const int nb23,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
|
@ -44,13 +47,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const int stride_KV = nb11 / sizeof(half);
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
Q += nb02* blockIdx.y + nb01*ic0;
|
||||
K += nb12*(blockIdx.y / gqa_ratio);
|
||||
V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
|
||||
|
||||
|
@ -83,17 +83,69 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
}
|
||||
__syncthreads();
|
||||
|
||||
// Convert Q to half2 and store in registers:
|
||||
float2 Q_h2[ncols][D/(2*WARP_SIZE)];
|
||||
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
|
||||
float2 Q_f2[ncols][D/(2*WARP_SIZE)];
|
||||
int Q_i32[ncols][D/(sizeof(int)*QK8_1) == 0 ? 1 : D >= D/(sizeof(int)*QK8_1)];
|
||||
float2 Q_ds[ncols][D/QK8_1 == 0 ? 1 : D/QK8_1];
|
||||
if (Q_q8_1) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
Q_h2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
|
||||
Q_h2[j][i0/WARP_SIZE].x *= scale;
|
||||
Q_h2[j][i0/WARP_SIZE].y *= scale;
|
||||
// Reuse KQ as temporary storage for converting Q to q8_1:
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
// Set memory to zero if out of bounds:
|
||||
if (ncols > 2 && ic0 + j >= ne01) {
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
tmp_q_i32[i] = 0;
|
||||
}
|
||||
if (threadIdx.x < D/QK8_1) {
|
||||
tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const float * Q_f = (const float *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
int * tmp_q_i32 = (int *) &KQ[j*D];
|
||||
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
|
||||
Q_ds[j][i0/WARP_SIZE] = tmp_q_ds[i/QI8_1];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
const float2 * Q_f2_j = (const float2 *) (Q + j*nb01);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
|
||||
Q_f2[j][i0/WARP_SIZE].x *= scale;
|
||||
Q_f2[j][i0/WARP_SIZE].y *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -117,28 +169,16 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
break;
|
||||
}
|
||||
|
||||
float sum[ncols] = {0.0f};
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
|
||||
const int k_KQ = k_KQ_0 + threadIdx.x;
|
||||
|
||||
const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
sum[j] += __low2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].x;
|
||||
sum[j] += __high2float(K_ik) * Q_h2[j][k_KQ_0/WARP_SIZE].y;
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
sum[j] = warp_reduce_sum(sum[j]);
|
||||
sum[j] += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum(sum);
|
||||
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum[j]);
|
||||
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
KQ[j*D + i_KQ] = sum[j];
|
||||
KQ[j*D + i_KQ] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -178,7 +218,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
break;
|
||||
}
|
||||
|
||||
const float V_ki = __half2float(V_h[(k_VKQ_0 + k)*stride_KV + tid]);
|
||||
const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
VKQ[j] += V_ki*KQ[j*D + k];
|
||||
|
@ -223,23 +263,65 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
template <int cols_per_block, int parallel_blocks>
|
||||
void launch_fattn_vec_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
case 64: {
|
||||
constexpr int D = 64;
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
||||
} break;
|
||||
case 128: {
|
||||
constexpr int D = 128;
|
||||
constexpr int nwarps = D/WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
|
||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
||||
} break;
|
||||
}
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
#ifdef GGML_CUDA_FA_ALL_QUANTS
|
||||
FATTN_VEC_CASE(f32, 64, f16, q4_0)
|
||||
FATTN_VEC_CASE(f32, 64, f16, q4_1)
|
||||
FATTN_VEC_CASE(f32, 64, f16, q5_0)
|
||||
FATTN_VEC_CASE(f32, 64, f16, q5_1)
|
||||
FATTN_VEC_CASE(f32, 64, f16, q8_0)
|
||||
FATTN_VEC_CASE(f32, 64, f16, f16)
|
||||
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, q4_1)
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, q5_0)
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, q5_1)
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, f16)
|
||||
|
||||
FATTN_VEC_CASE(f32, 128, q4_1, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, q4_1, q4_1)
|
||||
FATTN_VEC_CASE(f32, 128, q4_1, q5_0)
|
||||
FATTN_VEC_CASE(f32, 128, q4_1, q5_1)
|
||||
FATTN_VEC_CASE(f32, 128, q4_1, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, q4_1, f16)
|
||||
|
||||
FATTN_VEC_CASE(f32, 128, q5_0, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, q5_0, q4_1)
|
||||
FATTN_VEC_CASE(f32, 128, q5_0, q5_0)
|
||||
FATTN_VEC_CASE(f32, 128, q5_0, q5_1)
|
||||
FATTN_VEC_CASE(f32, 128, q5_0, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, q5_0, f16)
|
||||
|
||||
FATTN_VEC_CASE(f32, 128, q5_1, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, q5_1, q4_1)
|
||||
FATTN_VEC_CASE(f32, 128, q5_1, q5_0)
|
||||
FATTN_VEC_CASE(f32, 128, q5_1, q5_1)
|
||||
FATTN_VEC_CASE(f32, 128, q5_1, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, q5_1, f16)
|
||||
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, q4_1)
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, q5_0)
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, q5_1)
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, f16)
|
||||
|
||||
FATTN_VEC_CASE(f32, 128, f16, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, f16, q4_1)
|
||||
FATTN_VEC_CASE(f32, 128, f16, q5_0)
|
||||
FATTN_VEC_CASE(f32, 128, f16, q5_1)
|
||||
FATTN_VEC_CASE(f32, 128, f16, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, f16, f16)
|
||||
#else
|
||||
FATTN_VEC_CASE(f32, 128, q4_0, q4_0)
|
||||
FATTN_VEC_CASE(f32, 128, q8_0, q8_0)
|
||||
FATTN_VEC_CASE(f32, 128, f16, f16)
|
||||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
on_no_fattn_vec_case(Q->ne[0]);
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue