4 warps, 256 stride for all D

Author: Johannes Gäßler, 2024-03-31 18:39:02 +02:00 (committed by Georgi Gerganov)
parent 269374ed81
commit cca6d027a3


@@ -1,3 +1,4 @@
+#include "common.cuh"
 #include "fattn.cuh"

 #include <mma.h>
@@ -176,8 +177,10 @@ static __global__ void flash_attn_vec_ext_f16(
     dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum;
 }

-template<int D, int ncols> // D == head size
-__launch_bounds__(ncols == 8 || D > 128 ? D : 2*D, 1)
+#define FATTN_KQ_STRIDE 256
+
+template<int D, int ncols, int nwarps, int VKQ_stride> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
+__launch_bounds__(nwarps*WARP_SIZE, 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -206,6 +209,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne2,
         const int ne3) {
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
     static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
@@ -215,14 +219,13 @@ static __global__ void flash_attn_ext_f16(
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
     typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c;

-    constexpr int nwarps = (D <= 128 || ncols == 8 ? D : D/2) / frag_m;
-    constexpr int nthreads = nwarps*WARP_SIZE;
-    static_assert(nthreads % D == 0, "nthreads not divisible by D.");
-    constexpr int tc_vals_per_iter = nwarps*frag_m;
-    static_assert(D % tc_vals_per_iter == 0, "D not divisible by tensor core vals per iter.");
-    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
-    __builtin_assume(tid < nthreads);
-    constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts.
+    constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
+    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
+    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
+
+    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
+    constexpr int D_padded = D + 8;
+    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;

     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
     const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x);
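
The tiling arithmetic behind these new constants can be checked on the host. The sketch below mirrors the get_max_power_of_2/get_VKQ_stride helpers that this commit adds further down; the D, nwarps, and frag_m values are illustrative examples only:

// Standalone sketch (compiles as plain C++11 or CUDA); mirrors helpers defined later in this commit.
constexpr int get_max_power_of_2(int x) {
    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
}
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
}

// With nwarps == 4 and frag_m == 16 (the ncols > 8 case), KQ_stride_tc == 64:
// D == 128: VKQ_stride == 64 -> VKQ_ratio == 1, each warp accumulates its own block of VKQ rows.
// D ==  80: VKQ_stride == 16 -> VKQ_ratio == 4, all 4 warps share one 16-row block of VKQ and
//           each covers a different 16-wide slice of the K dimension.
static_assert(4*16 / get_VKQ_stride(128, 4, 16) == 1, "VKQ_ratio for D == 128");
static_assert(4*16 / get_VKQ_stride( 80, 4, 16) == 4, "VKQ_ratio for D == 80");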
@@ -235,31 +238,43 @@ static __global__ void flash_attn_ext_f16(
     frag_b Q_b[D/16][ncols/frag_n];

-    __shared__ half KQ[ncols*D_padded]; // Buffer for temporarily holding tiles of KQ.
+    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
+    constexpr int mem_KQ = ncols*kqs_padded;
+    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
     half2 * KQ2 = (half2 *) KQ;

-    half2 KQ_rowsum[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}};
-    half2 KQ_max[(ncols + nwarps - 1) / nwarps] = {{-INFINITY, -INFINITY}};
-    half2 KQ_max_scale[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}};
+    half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}};
+    half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}};

     __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
     half2 * VKQ2 = (half2 *) VKQ;

 #pragma unroll
-    for (int i0 = 0; i0 < ncols*D_padded/2; i0 += nthreads) {
-        const int i = i0 + tid;
-        if (i0 + nthreads > ncols*D_padded/2 && i >= ncols*D_padded/2) {
-            break;
-        }
-        VKQ2[i] = make_half2(0.0f, 0.0f);
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                break;
+            }
+            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+        }
     }

     // Convert Q to half and apply scale, temporarily store in KQ:
 #pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nthreads/D) {
-        const int j = j0 + tid/D;
-        const int i = tid % D;
-        KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+        }
     }

     __syncthreads();
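
A quick shared-memory budget check for the new buffer layout (a standalone sketch; the D = 128, ncols = 32, VKQ_ratio = 1 instantiation is an example and sizeof(half) == 2 is assumed):

// Host-side sizing sketch for the union KQ/VKQ-parts buffer plus the VKQ accumulator.
#include <cstdio>

int main() {
    const int FATTN_KQ_STRIDE = 256;
    const int D = 128, ncols = 32, VKQ_ratio = 1;       // example template parameters
    const int D_padded      = D + 8;                    // 136
    const int kqs_padded    = FATTN_KQ_STRIDE + 8;      // 264
    const int mem_KQ        = ncols*kqs_padded;         // 8448 half elements
    const int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; // 4352 half elements
    const int KQ_elems      = mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts;
    const int VKQ_elems     = ncols*D_padded;           // separate accumulator buffer
    printf("KQ buffer: %d B, VKQ accumulator: %d B, total: %d B\n",
           2*KQ_elems, 2*VKQ_elems, 2*(KQ_elems + VKQ_elems));
    return 0; // 16896 B + 8704 B = 25600 B, comfortably below the 48 KiB static limit
}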
@@ -276,31 +291,27 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();

     // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) {
-        const bool has_valid_data = 256 % D == 0 || k_VKQ_0 + frag_m*threadIdx.y < ne11;
-
+    for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += FATTN_KQ_STRIDE) {
         // Calculate tile of KQ:
 #pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
             frag_c KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
                 nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
             }
-            if (has_valid_data) {
-#pragma unroll
-                for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
-                    frag_a_K K_a;
-                    nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
-#pragma unroll
-                    for (int j = 0; j < ncols/frag_n; ++j) {
-                        nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
-                    }
-                }
-            }
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+                frag_a_K K_a;
+                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                }
+            }

 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major);
+                nvcuda::wmma::store_matrix_sync(KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
             }
         }
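
Because the tile height is now the fixed FATTN_KQ_STRIDE rather than D, the trip count of the i_KQ_0 loop above is identical for every head size. A compile-time sanity check of that arithmetic, assuming the nwarps = 4 configuration this commit installs:

// Iterations per KQ tile == FATTN_KQ_STRIDE / (nwarps*frag_m); warp w covers rows i_KQ_0 + frag_m*w.
static_assert(256 / (4*16) == 4, "frag_m == 16: 4 tensor-core iterations per tile");
static_assert(256 / (4*32) == 2, "frag_m == 32 (ncols == 8): 2 iterations per tile");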
@@ -311,18 +322,12 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += nwarps) {
             const int j = j0 + threadIdx.y;
-            if (j0 + nwarps > ncols && j >= ncols) {
-                break;
-            }

             half2 KQ_max_new = KQ_max[j0/nwarps];
 #pragma unroll
-            for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) {
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
                 const int k = k0 + threadIdx.x;
-                if (k0 + WARP_SIZE > D/2 && k >= D/2) {
-                    break;
-                }
-                KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(D_padded/2) + k]);
+                KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(kqs_padded/2) + k]);
             }
             KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
             KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new);
@@ -330,20 +335,14 @@ static __global__ void flash_attn_ext_f16(
             half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
 #pragma unroll
-            for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) {
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
                 const int k = k0 + threadIdx.x;
-                if (k0 + WARP_SIZE > D/2 && k >= D/2) {
-                    break;
-                }
-                if (256 % D != 0 && k_VKQ_0 + 2*k >= ne11) {
-                    break;
-                }
-                half2 val = KQ2[j*(D_padded/2) + k];
+                half2 val = KQ2[j*(kqs_padded/2) + k];
                 val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
                 val = h2exp(val - KQ_max[j0/nwarps]);
                 KQ_rowsum_add += val;
-                KQ2[j*(D_padded/2) + k] = val;
+                KQ2[j*(kqs_padded/2) + k] = val;
             }

             KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
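
The two loops above are one step of the standard online-softmax recurrence, vectorized over half2: the running maximum is raised, everything accumulated so far is rescaled by exp(m_old - m_new), and the freshly exponentiated values are added on. A scalar reference of the same recurrence (a sketch with made-up logits, not the kernel's code path):

// Scalar online-softmax reference; m/s correspond to KQ_max/KQ_rowsum, scale to KQ_max_scale.
#include <cassert>
#include <cmath>

int main() {
    const float logits[6] = {1.0f, 3.0f, -2.0f, 5.0f, 0.5f, 4.0f};
    float m = -INFINITY; // running maximum
    float s = 0.0f;      // running rescaled sum of exponentials
    for (float x : logits) {
        const float m_new = fmaxf(m, x);
        const float scale = expf(m - m_new); // rescales everything accumulated so far
        s = s*scale + expf(x - m_new);
        m = m_new;
        // In the kernel, the same scale factor also rescales the running VKQ accumulator.
    }
    float s_ref = 0.0f;
    for (float x : logits) s_ref += expf(x - 5.0f); // 5.0f is the true maximum
    assert(fabsf(s - s_ref) < 1e-5f);
    return 0;
}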
@@ -353,47 +352,46 @@ static __global__ void flash_attn_ext_f16(
         __syncthreads();

-        frag_b KQ_b[D/16][ncols/frag_n];
+        frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n];
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
 #pragma unroll
-            for (int k0 = 0; k0 < D; k0 += 16) {
-                nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*D_padded + k0, D_padded);
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) {
+                nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded);
             }
         }

-        frag_c VKQ_c[D/tc_vals_per_iter][ncols/frag_n];
+        frag_c VKQ_c[D/VKQ_stride][ncols/frag_n];
 #pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_KQ_0/tc_vals_per_iter][j], 0.0f);
+                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
             }

 #pragma unroll
-            for (int k0 = 0; k0 < D; k0 += 16) {
-                if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
-                    break;
-                }
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + i_KQ_0 + frag_m*threadIdx.y, stride_KV);
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_KQ_0/tc_vals_per_iter][j], v_a, KQ_b[k0/16][j], VKQ_c[i_KQ_0/tc_vals_per_iter][j]);
+                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }

         __syncthreads();

+        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
 #pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
                 nvcuda::wmma::store_matrix_sync(
-                    KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y,
-                    VKQ_c[i_KQ_0/tc_vals_per_iter][j0/frag_n],
+                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
                     D_padded, nvcuda::wmma::mem_col_major);
             }
         }
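
When VKQ_ratio > 1, threadIdx.y is split two ways in the loops above: threadIdx.y / VKQ_ratio selects the block of VKQ rows a warp accumulates and threadIdx.y % VKQ_ratio selects its 16-wide slice of the K dimension. A host-side sketch of that mapping for the D = 80 case (example values, assuming nwarps = 4 and frag_m = 16):

// Warp-to-work mapping for D = 80 => VKQ_stride = 16, VKQ_ratio = 4.
#include <cstdio>

int main() {
    const int VKQ_ratio = 4, frag_m = 16;
    for (int warp = 0; warp < 4; ++warp) { // warp == threadIdx.y
        const int row0    = frag_m*(warp / VKQ_ratio); // first VKQ row this warp accumulates
        const int k_slice = 16*(warp % VKQ_ratio);     // its K-dimension offset within each tile
        printf("warp %d: VKQ rows %d..%d, K offsets %d, %d, ...\n",
               warp, row0, row0 + frag_m - 1, k_slice, k_slice + VKQ_ratio*16);
    }
    return 0; // all 4 warps share rows 0..15 but cover disjoint K slices 0, 16, 32, 48
}

Each of the VKQ_ratio groups then stores its partial result at its own offset_k slice of the shared KQ buffer; those slices are reduced in the next hunk.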
@@ -403,16 +401,19 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += nwarps) {
             const int j = j0 + threadIdx.y;
-            if (j0 + nwarps > ncols && j >= ncols) {
-                break;
-            }
 #pragma unroll
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
                 if (i0 + WARP_SIZE > D/2 && i >= D/2) {
                     break;
                 }
-                VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + KQ2[j*(D_padded/2) + i];
+
+                half2 VKQ_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int l = 0; l < VKQ_ratio; ++l) {
+                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+                }
+                VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + VKQ_add;
             }
         }
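
Per output element, the update above computes VKQ_new = KQ_max_scale*VKQ_old + (sum of the VKQ_ratio partial results). A one-element host sketch of that fold (all numbers made up):

// Folding the VKQ_ratio partial sums into the running accumulator, one element.
#include <cassert>
#include <cmath>

int main() {
    const int VKQ_ratio = 4;
    const float KQ_max_scale = 0.5f;                          // softmax rescale factor
    const float parts[VKQ_ratio] = {0.1f, 0.2f, 0.3f, 0.4f};  // one partial sum per K-slice group
    float VKQ = 2.0f;                                         // running accumulator element
    float VKQ_add = 0.0f;
    for (int l = 0; l < VKQ_ratio; ++l) {
        VKQ_add += parts[l];
    }
    VKQ = KQ_max_scale*VKQ + VKQ_add; // same update the kernel performs per element
    assert(fabsf(VKQ - (0.5f*2.0f + 1.0f)) < 1e-6f);
    return 0;
}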
@@ -422,7 +423,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
-        if ((j0 + nwarps > ncols && j >= ncols) || ncols*blockIdx.x + j >= ne01) {
+        if (ncols*blockIdx.x + j >= ne01) {
             return;
         }
         const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]);
@@ -437,6 +438,50 @@ static __global__ void flash_attn_ext_f16(
     }
 }

+constexpr int get_max_power_of_2(int x) {
+    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
+}
+
+static_assert(get_max_power_of_2(1) == 1, "Test failed.");
+static_assert(get_max_power_of_2(2) == 2, "Test failed.");
+static_assert(get_max_power_of_2(4) == 4, "Test failed.");
+static_assert(get_max_power_of_2(6) == 2, "Test failed.");
+
+// Number of VKQ rows calculated in parallel:
+constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
+    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
+}
+
+static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
+static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
+
+#define FATTN_SWITCH_CASE(D, ncols, nwarps) \
+    case ncols: { \
+        constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \
+        flash_attn_ext_f16<D, ncols, nwarps, get_VKQ_stride(D, nwarps, frag_m)> \
+            <<<blocks_num, block_dim, shmem, main_stream>>> ( \
+                (const char *) Q->data, \
+                (const char *) K->data, \
+                (const char *) V->data, \
+                mask ? ((const char *) mask->data) : nullptr, \
+                (float *) KQV->data, \
+                scale, \
+                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \
+                K->ne[0], K->ne[1], K->ne[2], K->ne[3], \
+                mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \
+                Q->nb[1], Q->nb[2], Q->nb[3], \
+                K->nb[1], K->nb[2], K->nb[3], \
+                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \
+        ); \
+    } \
+    break; \

 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
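
For reference, this is roughly what a single invocation such as FATTN_SWITCH_CASE(64, 16, nwarps) expands to; the long argument list is elided here since the macro body above spells it out in full:

// Sketch of the macro expansion inside its enclosing switch, reformatted for readability.
switch (cols_per_block) {
    case 16: {
        constexpr int frag_m = (16) == 8 && (64) % 32 == 0 ? 32 : 16; // == 16
        flash_attn_ext_f16<64, 16, nwarps, get_VKQ_stride(64, nwarps, frag_m)>
            <<<blocks_num, block_dim, shmem, main_stream>>> (
                /* ... the argument list written out in the macro body ... */
        );
    }
    break;
    // ...
}

Note that nwarps can remain a host-side const int: a const int initialized with a constant expression is itself a constant expression, so it is valid as a template argument.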
@@ -580,7 +625,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }

     int cols_per_block;
-    if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) {
+    if (false && Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) {
         cols_per_block = 64;
     } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
         cols_per_block = 32;
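
With nwarps pinned to 4 (see the next hunk), the launch geometry is uniform across head sizes. A worked example with assumed tensor shapes, not values taken from this commit:

// Launch-geometry sketch: 512 tokens, 32 heads, batch 1 (assumed), cols_per_block = 32.
#include <cstdio>

int main() {
    const int ne1 = 512, ne2 = 32, ne3 = 1; // Q->ne[1..3] (assumed)
    const int WARP_SIZE = 32, nwarps = 4;
    const int cols_per_block = 32;          // chosen since ne1 >= 64 and the head size is <= 128
    const int blocks_x = (ne1 + cols_per_block - 1) / cols_per_block;
    printf("grid (%d, %d, %d), block (%d, %d, 1) = %d threads\n",
           blocks_x, ne2, ne3, WARP_SIZE, nwarps, WARP_SIZE*nwarps);
    return 0; // grid (16, 32, 1), block (32, 4, 1) = 128 threads per block
}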
@@ -590,451 +635,67 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         cols_per_block = 8;
     }
     const int frag_m = cols_per_block == 8 ? 32 : 16;
-    const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m;
+    const int nwarps = 4;
     const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     const size_t shmem = 0;

     switch (Q->ne[0]) {
         case 64: switch (cols_per_block) {
-            case 8:
-                flash_attn_ext_f16<64, 8>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 16:
-                flash_attn_ext_f16<64, 16>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 32:
-                flash_attn_ext_f16<64, 32>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 64:
-                flash_attn_ext_f16<64, 64>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
+            FATTN_SWITCH_CASE(64, 8, nwarps);
+            FATTN_SWITCH_CASE(64, 16, nwarps);
+            FATTN_SWITCH_CASE(64, 32, nwarps);
+            FATTN_SWITCH_CASE(64, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
                 break;
         } break;
         case 80: switch (cols_per_block) {
-            // case 8:
-            //     fused_attn_vec_ext_f16<80, 8>
-            //             <<<blocks_num, block_dim, shmem, main_stream>>> (
-            //             (const char *) Q->data, // Query
-            //             (const char *) K->data, // Key
-            //             (const char *) V->data, // Value
-            //             mask ? ((const char *) mask->data) : nullptr, // Mask
-            //             (float *) KQV->data, // dst
-            //             scale,
-            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            //             Q->nb[1], Q->nb[2], Q->nb[3],
-            //             K->nb[1], K->nb[2], K->nb[3],
-            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-            //     );
-            //     break;
-            case 16:
-                flash_attn_ext_f16<80, 16>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 32:
-                flash_attn_ext_f16<80, 32>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 64:
-                flash_attn_ext_f16<80, 64>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
+            // FATTN_SWITCH_CASE(80, 8, nwarps);
+            FATTN_SWITCH_CASE(80, 16, nwarps);
+            FATTN_SWITCH_CASE(80, 32, nwarps);
+            // FATTN_SWITCH_CASE(80, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
                 break;
         } break;
         case 96: switch (cols_per_block) {
-            case 8:
-                flash_attn_ext_f16<96, 8>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 16:
-                flash_attn_ext_f16<96, 16>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 32:
-                flash_attn_ext_f16<96, 32>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 64:
-                flash_attn_ext_f16<96, 64>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
+            FATTN_SWITCH_CASE(96, 8, nwarps);
+            FATTN_SWITCH_CASE(96, 16, nwarps);
+            FATTN_SWITCH_CASE(96, 32, nwarps);
+            FATTN_SWITCH_CASE(96, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
                 break;
         } break;
         case 112: switch (cols_per_block) {
-            // case 8:
-            //     fused_attn_vec_ext_f16<112, 8>
-            //             <<<blocks_num, block_dim, shmem, main_stream>>> (
-            //             (const char *) Q->data, // Query
-            //             (const char *) K->data, // Key
-            //             (const char *) V->data, // Value
-            //             mask ? ((const char *) mask->data) : nullptr, // Mask
-            //             (float *) KQV->data, // dst
-            //             scale,
-            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            //             Q->nb[1], Q->nb[2], Q->nb[3],
-            //             K->nb[1], K->nb[2], K->nb[3],
-            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-            //     );
-            //     break;
-            case 16:
-                flash_attn_ext_f16<112, 16>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 32:
-                flash_attn_ext_f16<112, 32>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 64:
-                flash_attn_ext_f16<112, 64>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
+            // FATTN_SWITCH_CASE(112, 8, nwarps);
+            FATTN_SWITCH_CASE(112, 16, nwarps);
+            FATTN_SWITCH_CASE(112, 32, nwarps);
+            // FATTN_SWITCH_CASE(112, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
                 break;
         } break;
         case 128: switch (cols_per_block) {
-            case 8:
-                flash_attn_ext_f16<128, 8>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 16:
-                flash_attn_ext_f16<128, 16>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 32:
-                flash_attn_ext_f16<128, 32>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 64:
-                flash_attn_ext_f16<128, 64>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
+            FATTN_SWITCH_CASE(128, 8, nwarps);
+            FATTN_SWITCH_CASE(128, 16, nwarps);
+            FATTN_SWITCH_CASE(128, 32, nwarps);
+            // FATTN_SWITCH_CASE(128, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
                 break;
         } break;
         case 256: switch (cols_per_block) {
-            case 8:
-                flash_attn_ext_f16<256, 8>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 16:
-                flash_attn_ext_f16<256, 16>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            case 32:
-                flash_attn_ext_f16<256, 32>
-                        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-                break;
-            // case 64:
-            //     flash_attn_ext_f16<256, 64>
-            //             <<<blocks_num, block_dim, shmem, main_stream>>> (
-            //             (const char *) Q->data, // Query
-            //             (const char *) K->data, // Key
-            //             (const char *) V->data, // Value
-            //             mask ? ((const char *) mask->data) : nullptr, // Mask
-            //             (float *) KQV->data, // dst
-            //             scale,
-            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-            //             Q->nb[1], Q->nb[2], Q->nb[3],
-            //             K->nb[1], K->nb[2], K->nb[3],
-            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-            //     );
-            //     break;
+            FATTN_SWITCH_CASE(256, 8, nwarps);
+            FATTN_SWITCH_CASE(256, 16, nwarps);
+            FATTN_SWITCH_CASE(256, 32, nwarps);
+            // FATTN_SWITCH_CASE(256, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);