4 warps, 256 stride for all D
This commit is contained in:
parent
269374ed81
commit
cca6d027a3
1 changed files with 147 additions and 486 deletions
|
@ -1,3 +1,4 @@
|
|||
#include "common.cuh"
|
||||
#include "fattn.cuh"
|
||||
|
||||
#include <mma.h>
|
||||
|
@ -176,8 +177,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
dst[D*gridDim.y*blockIdx.x + D*blockIdx.y + tid] = (__low2half(VKQ) + __high2half(VKQ)) / kqsum;
|
||||
}
|
||||
|
||||
template<int D, int ncols> // D == head size
|
||||
__launch_bounds__(ncols == 8 || D > 128 ? D : 2*D, 1)
|
||||
#define FATTN_KQ_STRIDE 256
|
||||
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride> // D == head size, VKQ_stride == num VKQ rows calculated in parallel
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
|
@ -206,6 +209,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int ne2,
|
||||
const int ne3) {
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
|
||||
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
|
||||
constexpr int frag_m = ncols == 8 ? 32 : 16;
|
||||
constexpr int frag_n = ncols == 8 ? 8 : 16;
|
||||
|
@ -215,14 +219,13 @@ static __global__ void flash_attn_ext_f16(
|
|||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c;
|
||||
|
||||
constexpr int nwarps = (D <= 128 || ncols == 8 ? D : D/2) / frag_m;
|
||||
constexpr int nthreads = nwarps*WARP_SIZE;
|
||||
static_assert(nthreads % D == 0, "nthreads not divisible by D.");
|
||||
constexpr int tc_vals_per_iter = nwarps*frag_m;
|
||||
static_assert(D % tc_vals_per_iter == 0, "D not divisible by tensor core vals per iter.");
|
||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||
__builtin_assume(tid < nthreads);
|
||||
constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts.
|
||||
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
|
||||
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
|
||||
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
|
||||
|
||||
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
|
||||
constexpr int D_padded = D + 8;
|
||||
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
|
||||
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + ncols*nb01*blockIdx.x);
|
||||
|
@ -235,32 +238,44 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
frag_b Q_b[D/16][ncols/frag_n];
|
||||
|
||||
__shared__ half KQ[ncols*D_padded]; // Buffer for temporarily holding tiles of KQ.
|
||||
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
|
||||
constexpr int mem_KQ = ncols*kqs_padded;
|
||||
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
|
||||
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
|
||||
half2 * KQ2 = (half2 *) KQ;
|
||||
|
||||
half2 KQ_rowsum[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}};
|
||||
half2 KQ_max[(ncols + nwarps - 1) / nwarps] = {{-INFINITY, -INFINITY}};
|
||||
half2 KQ_max_scale[(ncols + nwarps - 1) / nwarps] = {{0.0f, 0.0f}};
|
||||
half2 KQ_rowsum[ncols/nwarps] = {{0.0f, 0.0f}};
|
||||
half2 KQ_max[ncols/nwarps] = {{-INFINITY, -INFINITY}};
|
||||
half2 KQ_max_scale[ncols/nwarps] = {{0.0f, 0.0f}};
|
||||
|
||||
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
|
||||
half2 * VKQ2 = (half2 *) VKQ;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < ncols*D_padded/2; i0 += nthreads) {
|
||||
const int i = i0 + tid;
|
||||
if (i0 + nthreads > ncols*D_padded/2 && i >= ncols*D_padded/2) {
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
||||
break;
|
||||
}
|
||||
|
||||
VKQ2[i] = make_half2(0.0f, 0.0f);
|
||||
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert Q to half and apply scale, temporarily store in KQ:
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nthreads/D) {
|
||||
const int j = j0 + tid/D;
|
||||
const int i = tid % D;
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D && i >= D) {
|
||||
break;
|
||||
}
|
||||
KQ[j*D_padded + i] = ncols*blockIdx.x + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
@ -276,18 +291,15 @@ static __global__ void flash_attn_ext_f16(
|
|||
__syncthreads();
|
||||
|
||||
// Iterate over ne11 == previous tokens:
|
||||
for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += D) {
|
||||
const bool has_valid_data = 256 % D == 0 || k_VKQ_0 + frag_m*threadIdx.y < ne11;
|
||||
|
||||
for (int k_VKQ_0 = 0; k_VKQ_0 < ne11; k_VKQ_0 += FATTN_KQ_STRIDE) {
|
||||
// Calculate tile of KQ:
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
|
||||
frag_c KQ_c[ncols/frag_n];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
|
||||
}
|
||||
if (has_valid_data) {
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
|
||||
frag_a_K K_a;
|
||||
|
@ -297,10 +309,9 @@ static __global__ void flash_attn_ext_f16(
|
|||
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major);
|
||||
nvcuda::wmma::store_matrix_sync(KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -311,18 +322,12 @@ static __global__ void flash_attn_ext_f16(
|
|||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
if (j0 + nwarps > ncols && j >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
half2 KQ_max_new = KQ_max[j0/nwarps];
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) {
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
if (k0 + WARP_SIZE > D/2 && k >= D/2) {
|
||||
break;
|
||||
}
|
||||
KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(D_padded/2) + k]);
|
||||
KQ_max_new = __hmax2(KQ_max_new, KQ2[j*(kqs_padded/2) + k]);
|
||||
}
|
||||
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
|
||||
KQ_max_scale[j0/nwarps] = h2exp(KQ_max[j0/nwarps] - KQ_max_new);
|
||||
|
@ -330,20 +335,14 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < D/2; k0 += WARP_SIZE) {
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
|
||||
const int k = k0 + threadIdx.x;
|
||||
if (k0 + WARP_SIZE > D/2 && k >= D/2) {
|
||||
break;
|
||||
}
|
||||
if (256 % D != 0 && k_VKQ_0 + 2*k >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
half2 val = KQ2[j*(D_padded/2) + k];
|
||||
half2 val = KQ2[j*(kqs_padded/2) + k];
|
||||
val += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
|
||||
val = h2exp(val - KQ_max[j0/nwarps]);
|
||||
KQ_rowsum_add += val;
|
||||
KQ2[j*(D_padded/2) + k] = val;
|
||||
KQ2[j*(kqs_padded/2) + k] = val;
|
||||
}
|
||||
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
|
||||
|
||||
|
@ -353,47 +352,46 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
__syncthreads();
|
||||
|
||||
frag_b KQ_b[D/16][ncols/frag_n];
|
||||
frag_b KQ_b[FATTN_KQ_STRIDE/16][ncols/frag_n];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < D; k0 += 16) {
|
||||
nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*D_padded + k0, D_padded);
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += 16) {
|
||||
nvcuda::wmma::load_matrix_sync(KQ_b[k0/16][j0/frag_n], KQ + j0*kqs_padded + k0, kqs_padded);
|
||||
}
|
||||
}
|
||||
|
||||
frag_c VKQ_c[D/tc_vals_per_iter][ncols/frag_n];
|
||||
frag_c VKQ_c[D/VKQ_stride][ncols/frag_n];
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
|
||||
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::fill_fragment(VKQ_c[i_KQ_0/tc_vals_per_iter][j], 0.0f);
|
||||
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < D; k0 += 16) {
|
||||
if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
|
||||
break;
|
||||
}
|
||||
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
|
||||
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
|
||||
|
||||
frag_a_V v_a;
|
||||
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + i_KQ_0 + frag_m*threadIdx.y, stride_KV);
|
||||
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols/frag_n; ++j) {
|
||||
nvcuda::wmma::mma_sync(VKQ_c[i_KQ_0/tc_vals_per_iter][j], v_a, KQ_b[k0/16][j], VKQ_c[i_KQ_0/tc_vals_per_iter][j]);
|
||||
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k/16][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
|
||||
nvcuda::wmma::store_matrix_sync(
|
||||
KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y,
|
||||
VKQ_c[i_KQ_0/tc_vals_per_iter][j0/frag_n],
|
||||
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
|
||||
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
|
||||
D_padded, nvcuda::wmma::mem_col_major);
|
||||
}
|
||||
}
|
||||
|
@ -403,16 +401,19 @@ static __global__ void flash_attn_ext_f16(
|
|||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
if (j0 + nwarps > ncols && j >= ncols) {
|
||||
break;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
|
||||
break;
|
||||
}
|
||||
VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + KQ2[j*(D_padded/2) + i];
|
||||
|
||||
half2 VKQ_add = make_half2(0.0f, 0.0f);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < VKQ_ratio; ++l) {
|
||||
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
|
||||
}
|
||||
VKQ2[j*(D_padded/2) + i] = KQ_max_scale[j0/nwarps]*VKQ2[j*(D_padded/2) + i] + VKQ_add;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -422,7 +423,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
if ((j0 + nwarps > ncols && j >= ncols) || ncols*blockIdx.x + j >= ne01) {
|
||||
if (ncols*blockIdx.x + j >= ne01) {
|
||||
return;
|
||||
}
|
||||
const float KQ_rowsum_j = __low2float(KQ_rowsum[j0/nwarps]) + __high2float(KQ_rowsum[j0/nwarps]);
|
||||
|
@ -437,6 +438,50 @@ static __global__ void flash_attn_ext_f16(
|
|||
}
|
||||
}
|
||||
|
||||
constexpr int get_max_power_of_2(int x) {
|
||||
return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
|
||||
}
|
||||
|
||||
static_assert(get_max_power_of_2(1) == 1, "Test failed.");
|
||||
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
|
||||
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
|
||||
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
|
||||
|
||||
// Number of VKQ rows calculated in parallel:
|
||||
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
|
||||
return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
|
||||
}
|
||||
|
||||
static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
|
||||
static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
|
||||
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
|
||||
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
|
||||
|
||||
#define FATTN_SWITCH_CASE(D, ncols, nwarps) \
|
||||
case ncols: { \
|
||||
constexpr int frag_m = (ncols) == 8 && (D) % 32 == 0 ? 32 : 16; \
|
||||
flash_attn_ext_f16<D, ncols, nwarps, get_VKQ_stride(D, nwarps, frag_m)> \
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> ( \
|
||||
(const char *) Q->data, \
|
||||
(const char *) K->data, \
|
||||
(const char *) V->data, \
|
||||
mask ? ((const char *) mask->data) : nullptr, \
|
||||
(float *) KQV->data, \
|
||||
scale, \
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], \
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3], \
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, \
|
||||
Q->nb[1], Q->nb[2], Q->nb[3], \
|
||||
K->nb[1], K->nb[2], K->nb[3], \
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] \
|
||||
); \
|
||||
} \
|
||||
break; \
|
||||
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
@ -580,7 +625,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
}
|
||||
|
||||
int cols_per_block;
|
||||
if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) {
|
||||
if (false && Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) {
|
||||
cols_per_block = 64;
|
||||
} else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
|
||||
cols_per_block = 32;
|
||||
|
@ -590,451 +635,67 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
cols_per_block = 8;
|
||||
}
|
||||
const int frag_m = cols_per_block == 8 ? 32 : 16;
|
||||
const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m;
|
||||
const int nwarps = 4;
|
||||
const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
|
||||
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
||||
const size_t shmem = 0;
|
||||
|
||||
switch (Q->ne[0]) {
|
||||
case 64: switch (cols_per_block) {
|
||||
case 8:
|
||||
flash_attn_ext_f16<64, 8>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 16:
|
||||
flash_attn_ext_f16<64, 16>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 32:
|
||||
flash_attn_ext_f16<64, 32>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 64:
|
||||
flash_attn_ext_f16<64, 64>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
FATTN_SWITCH_CASE(64, 8, nwarps);
|
||||
FATTN_SWITCH_CASE(64, 16, nwarps);
|
||||
FATTN_SWITCH_CASE(64, 32, nwarps);
|
||||
FATTN_SWITCH_CASE(64, 64, nwarps);
|
||||
default:
|
||||
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
} break;
|
||||
case 80: switch (cols_per_block) {
|
||||
// case 8:
|
||||
// fused_attn_vec_ext_f16<80, 8>
|
||||
// <<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
// (const char *) Q->data, // Query
|
||||
// (const char *) K->data, // Key
|
||||
// (const char *) V->data, // Value
|
||||
// mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
// (float *) KQV->data, // dst
|
||||
// scale,
|
||||
// Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
// K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
// mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
// Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
// K->nb[1], K->nb[2], K->nb[3],
|
||||
// KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
// );
|
||||
// break;
|
||||
case 16:
|
||||
flash_attn_ext_f16<80, 16>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 32:
|
||||
flash_attn_ext_f16<80, 32>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 64:
|
||||
flash_attn_ext_f16<80, 64>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
// FATTN_SWITCH_CASE(80, 8, nwarps);
|
||||
FATTN_SWITCH_CASE(80, 16, nwarps);
|
||||
FATTN_SWITCH_CASE(80, 32, nwarps);
|
||||
// FATTN_SWITCH_CASE(80, 64, nwarps);
|
||||
default:
|
||||
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
} break;
|
||||
case 96: switch (cols_per_block) {
|
||||
case 8:
|
||||
flash_attn_ext_f16<96, 8>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 16:
|
||||
flash_attn_ext_f16<96, 16>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 32:
|
||||
flash_attn_ext_f16<96, 32>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 64:
|
||||
flash_attn_ext_f16<96, 64>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
FATTN_SWITCH_CASE(96, 8, nwarps);
|
||||
FATTN_SWITCH_CASE(96, 16, nwarps);
|
||||
FATTN_SWITCH_CASE(96, 32, nwarps);
|
||||
FATTN_SWITCH_CASE(96, 64, nwarps);
|
||||
default:
|
||||
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
} break;
|
||||
case 112: switch (cols_per_block) {
|
||||
// case 8:
|
||||
// fused_attn_vec_ext_f16<112, 8>
|
||||
// <<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
// (const char *) Q->data, // Query
|
||||
// (const char *) K->data, // Key
|
||||
// (const char *) V->data, // Value
|
||||
// mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
// (float *) KQV->data, // dst
|
||||
// scale,
|
||||
// Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
// K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
// mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
// Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
// K->nb[1], K->nb[2], K->nb[3],
|
||||
// KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
// );
|
||||
// break;
|
||||
case 16:
|
||||
flash_attn_ext_f16<112, 16>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 32:
|
||||
flash_attn_ext_f16<112, 32>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 64:
|
||||
flash_attn_ext_f16<112, 64>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
// FATTN_SWITCH_CASE(112, 8, nwarps);
|
||||
FATTN_SWITCH_CASE(112, 16, nwarps);
|
||||
FATTN_SWITCH_CASE(112, 32, nwarps);
|
||||
// FATTN_SWITCH_CASE(112, 64, nwarps);
|
||||
default:
|
||||
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
} break;
|
||||
case 128: switch (cols_per_block) {
|
||||
case 8:
|
||||
flash_attn_ext_f16<128, 8>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 16:
|
||||
flash_attn_ext_f16<128, 16>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 32:
|
||||
flash_attn_ext_f16<128, 32>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 64:
|
||||
flash_attn_ext_f16<128, 64>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
FATTN_SWITCH_CASE(128, 8, nwarps);
|
||||
FATTN_SWITCH_CASE(128, 16, nwarps);
|
||||
FATTN_SWITCH_CASE(128, 32, nwarps);
|
||||
// FATTN_SWITCH_CASE(128, 64, nwarps);
|
||||
default:
|
||||
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
} break;
|
||||
case 256: switch (cols_per_block) {
|
||||
case 8:
|
||||
flash_attn_ext_f16<256, 8>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 16:
|
||||
flash_attn_ext_f16<256, 16>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
case 32:
|
||||
flash_attn_ext_f16<256, 32>
|
||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
(const char *) Q->data, // Query
|
||||
(const char *) K->data, // Key
|
||||
(const char *) V->data, // Value
|
||||
mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
(float *) KQV->data, // dst
|
||||
scale,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
K->nb[1], K->nb[2], K->nb[3],
|
||||
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
);
|
||||
break;
|
||||
// case 64:
|
||||
// flash_attn_ext_f16<256, 64>
|
||||
// <<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||
// (const char *) Q->data, // Query
|
||||
// (const char *) K->data, // Key
|
||||
// (const char *) V->data, // Value
|
||||
// mask ? ((const char *) mask->data) : nullptr, // Mask
|
||||
// (float *) KQV->data, // dst
|
||||
// scale,
|
||||
// Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
// K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
// mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
// Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
// K->nb[1], K->nb[2], K->nb[3],
|
||||
// KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||
// );
|
||||
// break;
|
||||
FATTN_SWITCH_CASE(256, 8, nwarps);
|
||||
FATTN_SWITCH_CASE(256, 16, nwarps);
|
||||
FATTN_SWITCH_CASE(256, 32, nwarps);
|
||||
// FATTN_SWITCH_CASE(256, 64, nwarps);
|
||||
default:
|
||||
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
||||
GGML_ASSERT(false);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue