Calculate KQ as FP32 if KQV has GGML_PREC_F32

parent a5b0e2dea0
commit 0bc67dd1c8

1 changed file with 213 additions and 73 deletions
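In outline, the change threads an accumulator type `KQ_acc_t` (either `half` or `float`) through `flash_attn_ext_f16` as a template parameter: the KQ = K*Q^T tile accumulation and the softmax run in that type, while the VKQ accumulation stays FP16. The host picks `float` whenever the precision stored on the op is not `GGML_PREC_DEFAULT`. A minimal sketch of that dispatch pattern only, with placeholder names (`attn_sketch`, `launch_sketch`, and `sketch_prec` are illustrative, not symbols from this file):

```cuda
#include <cstdint>
#include <cuda_fp16.h>

// The kernel is compiled once per accumulator type: half is the fast path,
// float trades speed for a numerically safer KQ accumulation and softmax.
template <typename KQ_acc_t>
__global__ void attn_sketch(/* Q, K, V, mask, dst ... */) {
    // sizeof(KQ_acc_t)/sizeof(half) is 1 or 2; the real kernel uses this
    // ratio ("kqar") to widen shared-memory strides on the FP32 path.
}

// Stand-ins for GGML_PREC_DEFAULT / GGML_PREC_F32:
enum sketch_prec : int32_t { SKETCH_PREC_DEFAULT = 0, SKETCH_PREC_F32 = 1 };

void launch_sketch(int32_t precision) { // as read back from the op's parameters
    if (precision != SKETCH_PREC_DEFAULT) {
        attn_sketch<float><<<1, 32>>>(); // FP32 KQ accumulation
    } else {
        attn_sketch<half><<<1, 32>>>();  // default FP16 path
    }
}
```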
@@ -1,6 +1,7 @@
 #include "common.cuh"
 #include "fattn.cuh"
 
+#include <cstdint>
 #include <mma.h>
 
 #define FATTN_KQ_STRIDE 256
@@ -185,7 +186,8 @@ static __global__ void flash_attn_vec_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
+// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
@@ -229,7 +231,8 @@ static __global__ void flash_attn_ext_f16(
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
 
     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -238,12 +241,14 @@ static __global__ void flash_attn_ext_f16(
     // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
     constexpr int D_padded = D + 8;
     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
     const float * Q_f   = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
     const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
     const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half2 * mask2 = (const half2 *) mask + ne11*(ic0/2);
+    const half  * maskh = (const half  *) mask + (nb31/sizeof(half))* ic0;
+    const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
 
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
@@ -251,14 +256,29 @@ static __global__ void flash_attn_ext_f16(
     frag_b Q_b[D/16][ncols/frag_n];
 
     // A single buffer for temporarily holding tiles of KQ and VKQ parts:
-    constexpr int mem_KQ = ncols*kqs_padded;
+    constexpr int mem_KQ = ncols*kqs_padded*kqar;
     constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
     __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
+    float * KQ_f = (float *) KQ;
     half2 * KQ2 = (half2 *) KQ;
 
-    half2 KQ_rowsum[ncols/nwarps] = {{ 0.0f, 0.0f}};
-    half2 KQ_max[ncols/nwarps] = {{-HALF_MAX_HALF, -HALF_MAX_HALF}};
-    half2 KQ_max_scale[ncols/nwarps] = {{ 0.0f, 0.0f}};
+    float KQ_rowsum_f[ncols/nwarps] = {0.0f};
+    float KQ_max_f[ncols/nwarps];
+    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_f[j] = -FLT_MAX/2.0f;
+    }
+
+    half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2 KQ_max_h2[ncols/nwarps];
+    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
+    }
 
     __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
     half2 * VKQ2 = (half2 *) VKQ;
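The aliasing above is the core storage trick: the shared KQ tile stays declared as `half` even on the FP32 path, `kqar = sizeof(KQ_acc_t)/sizeof(half)` (1 for `half`, 2 for `float`) widens the buffer and the row strides, and `KQ_f`/`KQ2` reinterpret the same storage. A standalone illustration of the size arithmetic, with `ncols = 16` as an assumed example value:

```cuda
#include <cstdio>
#include <cuda_fp16.h>

int main() {
    constexpr int ncols      = 16;      // assumed example value
    constexpr int kqs_padded = 256 + 8; // FATTN_KQ_STRIDE + 8, as above
    constexpr int kqar_half  = sizeof(half) /sizeof(half); // 1: KQ_acc_t = half
    constexpr int kqar_float = sizeof(float)/sizeof(half); // 2: KQ_acc_t = float

    // Element count of the half-typed shared buffer per accumulator type:
    printf("mem_KQ with half  accumulators: %d\n", ncols*kqs_padded*kqar_half);  // 4224
    printf("mem_KQ with float accumulators: %d\n", ncols*kqs_padded*kqar_float); // 8448
    return 0;
}
```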
@@ -307,7 +327,7 @@ static __global__ void flash_attn_ext_f16(
         // Calculate tile of KQ:
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
-            frag_c KQ_c[ncols/frag_n];
+            frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
                 nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
@@ -323,7 +343,7 @@ static __global__ void flash_attn_ext_f16(
             }
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
             }
         }
 
@@ -335,6 +355,50 @@ static __global__ void flash_attn_ext_f16(
         for (int j0 = 0; j0 < ncols; j0 += nwarps) {
             const int j = j0 + threadIdx.y;
 
+            if (std::is_same<KQ_acc_t, float>::value) {
+                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+                }
+
+                float KQ_max_new = KQ_max_f[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = warp_reduce_max(KQ_max_new);
+
+                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_f[j0/nwarps] = expf(diff);
+                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                    KQ_max_scale_f[j0/nwarps] = 0.0f;
+                }
+                KQ_max_f[j0/nwarps] = KQ_max_new;
+
+                float KQ_rowsum_add = 0.0f;
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                    }
+                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
+            } else {
                 half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
 #pragma unroll
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
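The FP32 branch just added is the standard streaming-softmax update: raise the running row maximum, rescale the old partial sum by exp(old_max - new_max), and flush differences at or below `SOFTMAX_FTZ_THRESHOLD` to zero. A scalar model of the same update, single-threaded and without the warp reductions (`-20.0f` is an assumed stand-in for the file's `SOFTMAX_FTZ_THRESHOLD`):

```cuda
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdio>

constexpr float FTZ_THRESHOLD = -20.0f; // assumption, see lead-in

// One streaming-softmax step over a chunk s[0..n) of a row's KQ values:
// raises the running maximum, rescales the partial row sum, and writes the
// unnormalized softmax weights for this chunk into p.
void softmax_row_step(const float * s, int n, float & row_max, float & row_sum, float * p) {
    float max_new = row_max;
    for (int k = 0; k < n; ++k) {
        max_new = std::max(max_new, s[k]);
    }

    const float diff0 = row_max - max_new;
    const float scale = diff0 <= FTZ_THRESHOLD ? 0.0f : std::exp(diff0); // flush-to-zero
    row_max = max_new;

    float rowsum_add = 0.0f;
    for (int k = 0; k < n; ++k) {
        const float diff = s[k] - max_new;
        p[k] = diff <= FTZ_THRESHOLD ? 0.0f : std::exp(diff);
        rowsum_add += p[k];
    }
    row_sum = scale*row_sum + rowsum_add; // old sum rescaled for the raised max
}

int main() {
    const float s[4] = {0.5f, 1.0f, -0.25f, 2.0f};
    float p[4];
    float row_max = -FLT_MAX/2.0f;
    float row_sum = 0.0f;
    softmax_row_step(s, 4, row_max, row_sum, p);
    printf("row_max = %f, row_sum = %f\n", row_max, row_sum); // row_max = 2.0
    return 0;
}
```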
@@ -343,7 +407,7 @@ static __global__ void flash_attn_ext_f16(
                     KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
                 }
 
-                half2 KQ_max_new = KQ_max[j0/nwarps];
+                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
 #pragma unroll
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
                     const int k = k0 + threadIdx.x;
@@ -352,18 +416,18 @@ static __global__ void flash_attn_ext_f16(
                     KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
                 }
                 KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
-                const half2 diff = KQ_max[j0/nwarps] - KQ_max_new;
-                KQ_max_scale[j0/nwarps] = h2exp(diff);
+                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
                 const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                *((uint *) &KQ_max_scale[j0/nwarps]) &= ftz_mask;
-                KQ_max[j0/nwarps] = KQ_max_new;
+                *((uint *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
+                KQ_max_h2[j0/nwarps] = KQ_max_new;
 
                 half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
 #pragma unroll
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
                     const int k = k0 + threadIdx.x;
 
-                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max[j0/nwarps];
+                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
                     KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
                     const uint ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
                     *((uint *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
@@ -373,7 +437,8 @@ static __global__ void flash_attn_ext_f16(
                 KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
 
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
-                KQ_rowsum[j0/nwarps] = KQ_max_scale[j0/nwarps]*KQ_rowsum[j0/nwarps] + KQ_rowsum_add;
+                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
+            }
         }
 
         __syncthreads();
@@ -386,12 +451,12 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
                 nvcuda::wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
-                    KQ + j0*kqs_padded + k,
-                    kqs_padded);
+                    KQ + j0*(kqar*kqs_padded) + k,
+                    kqar*kqs_padded);
             }
         }
 
-        frag_c VKQ_c[D/VKQ_stride][ncols/frag_n];
+        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
 #pragma unroll
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
@@ -431,6 +496,14 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += nwarps) {
             const int j = j0 + threadIdx.y;
+
+            half2 VKQ_scale;
+            if (std::is_same<KQ_acc_t, float>::value) {
+                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
+            } else {
+                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
+            }
+
 #pragma unroll
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
@@ -443,7 +516,7 @@ static __global__ void flash_attn_ext_f16(
                 for (int l = 0; l < VKQ_ratio; ++l) {
                     VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
                 }
-                VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
             }
         }
 
@@ -458,14 +531,20 @@ static __global__ void flash_attn_ext_f16(
         }
         const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
 
-        const half KQ_rowsum_j = __low2half(KQ_rowsum[j0/nwarps]) + __high2half(KQ_rowsum[j0/nwarps]);
+        float KQ_rowsum_j;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
+        } else {
+            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
+        }
+
 #pragma unroll
         for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
             const int i = i0 + threadIdx.x;
             if (i0 + WARP_SIZE > D && i >= D) {
                 break;
             }
-            half dst_val = VKQ[j_VKQ*D_padded + i];
+            float dst_val = VKQ[j_VKQ*D_padded + i];
             if (parallel_blocks == 1) {
                 dst_val /= KQ_rowsum_j;
             }
@@ -476,7 +555,12 @@ static __global__ void flash_attn_ext_f16(
             continue;
         }
 
-        half2 dst_meta_val = KQ_max[j0/nwarps];
+        half2 dst_meta_val;
+        if (std::is_same<KQ_acc_t, float>::value) {
+            reinterpret_cast<half&>(dst_meta_val.x) = KQ_max_f[j0/nwarps];
+        } else {
+            dst_meta_val = KQ_max_h2[j0/nwarps];
+        }
         reinterpret_cast<half&>(dst_meta_val.y) = KQ_rowsum_j;
         dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
     }
@@ -606,7 +690,7 @@ template <int D, int parallel_blocks> void launch_fattn_vec_f16(
     CUDA_CHECK(cudaGetLastError());
 }
 
-template <int D, int cols_per_block, int nwarps, int parallel_blocks> void launch_fattn_f16_impl(
+template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
         const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
         ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {
@@ -626,7 +710,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks> void launc
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
-    flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks>
+    flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
         (const char *) Q->data,
         (const char *) K->data,
@@ -657,21 +741,21 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks> void launc
     CUDA_CHECK(cudaGetLastError());
 }
 
-template <int D, int cols_per_block, int nwarps> void launch_fattn_f16(
+template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
         const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
         const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
 ) {
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
 
     if (4*blocks_num_pb1 < 2*nsm) {
-        launch_fattn_f16_impl<D, cols_per_block, nwarps, 4>(Q, K, V, KQV, mask, pool, main_stream);
+        launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
         return;
     }
    if (2*blocks_num_pb1 < 2*nsm) {
-        launch_fattn_f16_impl<D, cols_per_block, nwarps, 2>(Q, K, V, KQV, mask, pool, main_stream);
+        launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
         return;
     }
-    launch_fattn_f16_impl<D, cols_per_block, nwarps, 1>(Q, K, V, KQV, mask, pool, main_stream);
+    launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
 }
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -696,15 +780,73 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
 
-    if (Q->ne[1] == 1 && Q->ne[0] % WARP_SIZE == 0) {
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    const int32_t precision = KQV->op_params[1];
+
+    if (precision != GGML_PREC_DEFAULT) {
+        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+            constexpr int cols_per_block = 16;
+            constexpr int nwarps = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 80:
+                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 96:
+                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 112:
+                    launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 256:
+                    launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+        } else {
+            constexpr int cols_per_block = 32;
+            constexpr int nwarps = 4;
+            switch (Q->ne[0]) {
+                case 64:
+                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 80:
+                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 96:
+                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 112:
+                    launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                case 128:
+                    launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    break;
+                // case 256:
+                //     launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                //     break;
+                default:
+                    GGML_ASSERT(false);
+                    break;
+            }
+        }
+        return;
+    }
+
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
         constexpr int parallel_blocks = 4;
         switch (Q->ne[0]) {
             case 64:
                 launch_fattn_vec_f16< 64, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             case 96:
                 launch_fattn_vec_f16< 96, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
             case 128:
                 launch_fattn_vec_f16<128, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
                 break;
@@ -718,23 +860,21 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-
     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
         constexpr int cols_per_block = 8;
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             default:
                 GGML_ASSERT(false);
@@ -748,22 +888,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 80:
-                launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 112:
-                launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             default:
                 GGML_ASSERT(false);
@@ -776,22 +916,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 80:
-                launch_fattn_f16< 80, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 112:
-                launch_fattn_f16<112, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
                 break;
             default:
                 GGML_ASSERT(false);
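Downstream of this file, the precision that `ggml_cuda_flash_attn_ext` reads back from `KQV->op_params[1]` has to be set on the attention node before the graph runs. A usage sketch, assuming the `ggml_flash_attn_ext_set_prec` helper and the `ggml_flash_attn_ext(ctx, q, k, v, mask, scale)` signature from the `ggml.h` of this era; verify both against your tree, since `op_params[1]` is the raw slot either way:

```cuda
#include "ggml.h"

// Usage sketch: request FP32 KQ accumulation for one flash-attention node.
// Both ggml calls below are assumptions about the contemporary ggml.h API.
static struct ggml_tensor * build_attn_f32_kq(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        struct ggml_tensor  * mask,
        float scale) {
    struct ggml_tensor * kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale);
    ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); // lands in op_params[1]
    return kqv; // the CUDA dispatch above then instantiates KQ_acc_t = float
}
```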