integrate tensor cores

FSSRepo 2024-01-26 20:14:02 -05:00
parent 6e7cb0eeaf
commit 0a481fe1a9


@@ -104,6 +104,7 @@
 #include <cuda.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#include <mma.h>
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
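Note: <mma.h> provides the nvcuda::wmma tensor-core intrinsics used further down; device code that touches them has to be compiled for Volta (sm_70) or newer. The commit adds no explicit guard for that, but a minimal sketch of how such a gate could look (the FA_HAVE_WMMA macro is purely illustrative, not part of ggml) is:

    // Illustrative only: compile-time gate for tensor-core (WMMA) code paths.
    // nvcuda::wmma requires compute capability 7.0+ device code.
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
    #define FA_HAVE_WMMA 1
    #else
    #define FA_HAVE_WMMA 0
    #endif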
@@ -621,6 +622,14 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }

+static __device__ __forceinline__ __half warp_reduce_sum(__half x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
@@ -642,6 +651,19 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 }

+static __device__ __forceinline__ half warp_reduce_max(half x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#else
+    (void) x;
+    bad_arch();
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
+
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
     GGML_UNUSED(a);
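Note: the two new __half reducers follow the same XOR-butterfly pattern as the existing float/half2 helpers: five __shfl_xor_sync rounds (mask 16, 8, 4, 2, 1) leave every lane of the warp holding the sum (or max) of all 32 lane values, which is what the online-softmax code further down relies on. A standalone float sketch of the same pattern (the warp_sum_demo/reduce_demo names are illustrative, not from ggml-cuda.cu):

    #include <cstdio>
    #include <cuda_runtime.h>

    // XOR-butterfly warp reduction: each round exchanges values between lanes
    // whose IDs differ by `mask`; after 5 rounds all 32 lanes hold the total.
    __device__ float warp_sum_demo(float x) {
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            x += __shfl_xor_sync(0xffffffff, x, mask, 32);
        }
        return x;
    }

    __global__ void reduce_demo(const float * in, float * out) {
        const float v = in[threadIdx.x];  // one value per lane
        const float s = warp_sum_demo(v); // every lane receives the warp total
        if (threadIdx.x == 0) {
            *out = s;
        }
    }

    int main() {
        float h_in[32], h_out = 0.0f, *d_in, *d_out;
        for (int i = 0; i < 32; ++i) h_in[i] = 1.0f; // expected sum: 32
        cudaMalloc(&d_in, 32*sizeof(float));
        cudaMalloc(&d_out, sizeof(float));
        cudaMemcpy(d_in, h_in, 32*sizeof(float), cudaMemcpyHostToDevice);
        reduce_demo<<<1, 32>>>(d_in, d_out);
        cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
        printf("warp sum = %f\n", h_out); // prints 32.000000
        cudaFree(d_in); cudaFree(d_out);
        return 0;
    }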
@@ -6112,6 +6134,10 @@ static __global__ void flash_attn_f32(
     }
 }

+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    16, 16, 16, __half, nvcuda::wmma::col_major> half16x16_a;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    16, 16, 16, __half, nvcuda::wmma::col_major> half16x16_b;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, __half>                          half16x16_acc;
+
 // based on metal version
 template<int D, int Q, int C> // D head size, Q queries per block, C cache items per blocks
 static __global__ void flash_attn_ext_f16(
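Note: the half16x16_a/_b/_acc typedefs wrap nvcuda::wmma fragments for 16x16x16 half tiles, which is the granularity the kernel uses for both Q*K^T and (Q*K^T)*V. A minimal, self-contained sketch of the fragment API on a single tile (column-major, leading dimension 16; the wmma_tile_demo name and the plain ld=16 strides are illustrative, whereas the kernel strides by T and nb11/sizeof(half)):

    #include <cuda_fp16.h>
    #include <mma.h>

    using namespace nvcuda;

    // One warp computes C = A*B for a single 16x16 half tile on the tensor cores.
    // Compile for sm_70 or newer.
    __global__ void wmma_tile_demo(const half * A, const half * B, half * C) {
        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::col_major> a;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> b;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half>                  c;

        wmma::fill_fragment(c, __float2half(0.0f)); // zero accumulator
        wmma::load_matrix_sync(a, A, 16);           // ld = 16 halfs per column
        wmma::load_matrix_sync(b, B, 16);
        wmma::mma_sync(c, a, b, c);                 // c += a*b
        wmma::store_matrix_sync(C, c, 16, wmma::mem_col_major);
    }

Launched as wmma_tile_demo<<<1, 32>>>(dA, dB, dC), i.e. exactly one warp per tile; the kernel below drives the fragments the same way, with each warp walking D/16 tiles per 16-column slice of the KV cache.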
@@ -6152,17 +6178,17 @@ static __global__ void flash_attn_ext_f16(
     const int D2 = D/2;
     const int N4 = WARP_SIZE;
     const int L2 = (D2 + N4 - 1)/N4;
-    const int D8 = D/8;
+    const int D16 = D/16;

     const int T  = D + n_warps*(D + 1*C); // shared memory size per query in half
     const int T2 = T/2;                   // shared memory size per query in half2

-    const half2 scale_h = __half2half2(__float2half(scale));
+    const half scale_h = __float2half(scale);

     extern __shared__ char data_flash_attn_shmem[];
     // pq
     half  * pq  = (half  *) (data_flash_attn_shmem + 0*D);
     half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D);
     half  * ps  = (half  *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
     half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
     half  * ss  = (half  *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D);
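Note: T = D + n_warps*(D + 1*C) is the per-query row stride of the shared buffer, laid out as the query row (pq) followed by one (partial output, scores) pair per warp (ps, ss); the j*T / j*T2 terms at the use sites in the next hunk select the row for query j. A host-side sketch of the offsets implied by the pointer setup above, expressed in half elements (the FlashAttnSmemLayout helper is hypothetical, not in ggml; the total-bytes formula assumes Q query rows per block, per the "per query" comments):

    #include <cstddef>
    #include <cuda_fp16.h>

    // Illustrative only: per-query shared-memory layout used by flash_attn_ext_f16.
    //   [0, D)                             : query row (pq)
    //   [w*(D+C) + D,   w*(D+C) + 2*D)     : warp w's partial output row (ps)
    //   [w*(D+C) + 2*D, w*(D+C) + 2*D + C) : warp w's Q*K^T scores (ss)
    struct FlashAttnSmemLayout {
        int D, C, n_warps, Q;

        int T() const { return D + n_warps*(D + 1*C); }                 // stride per query, in halfs
        int pq(int j)        const { return j*T(); }
        int ps(int j, int w) const { return j*T() + w*(D + 1*C) + 1*D; }
        int ss(int j, int w) const { return j*T() + w*(D + 1*C) + 2*D; }

        size_t bytes() const { return (size_t) Q*T()*sizeof(half); }    // assumed total per block
    };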
@@ -6191,120 +6217,185 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();

-    const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
-    half S[8] = { 0.0 };
-#if 0
-    half2 M = make_half2(-INFINITY, -INFINITY);
-
-    const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
-
-    for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
-        const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0);
-        if (__hisinf(mv.x) == -1) { // mv == -INFINITY
-            continue;
-        }
-
-        half2 * pk2 = (half2 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
-        half2 * pv2 = (half2 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K
-
-        half2 s2 = make_half2(0.0, 0.0);
-
-#pragma unroll
-        for (int i = 0; i < D2/tph; ++i) {
-            s2 = pq2[hiiw*D2 + tph*i + tiih] * pk2[tph*i + tiih] + s2;
-        }
-
-        ss[hiiw*tph + tiih] = __half2half2(s2.x + s2.y);
-
-        __syncthreads();
-
-        if (tiih == 0) {
-            half2 s = make_half2(0.0, 0.0);
-
-#pragma unroll
-            for (int i = 0; i < tph; ++i) {
-                s += ss[hiiw*tph + i];
-            }
-
-            s = s*scale_h + mv; // s*scale + mv
-
-            half2 m = M;
-
-            M = __hmax2(M, s);
-
-            half2 ms = h2exp(m - M);
-            half2 vs = h2exp(s - M);
-
-            S = S*ms + vs;
-
-            ss[2*hiiw + 0] = ms;
-            ss[2*hiiw + 1] = vs;
-        }
-
-        __syncthreads();
-
-        half2 ms = ss[2*hiiw + 0];
-        half2 vs = ss[2*hiiw + 1];
-
-#pragma unroll
-        for (int i = 0; i < D2/tph; ++i) {
-            ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms + pv2[tph*i + tiih]*vs;
-        }
-    }
-
-    if (tiih == 0) {
-        ss[2*hiiw + 0] = S;
-        ss[2*hiiw + 1] = M;
-    }
+    {
+        half S[Q] = { 0.0 };
+        half M[Q] = { -INFINITY };
+
+        // assume K and V are same shape
+        const int ne22 = ne12;
+        const int ne23 = ne13;
+
+        const int nb21 = nb11;
+        const int nb22 = nb12;
+        const int nb23 = nb13;
+
+        // broadcast
+        const int rk2 = ne02/ne12;
+        const int rk3 = ne03/ne13;
+
+        const int rv2 = ne02/ne22;
+        const int rv3 = ne03/ne23;
+
+        // k indices
+        const int ik2 = iq2 / rk2;
+        const int ik3 = iq3 / rk3;
+
+        // v indices
+        const int iv2 = iq2 / rv2;
+        const int iv3 = iq3 / rv3;
+
+        // TODO: this can be improved
+        float * mp[Q];
+
+        {
+            const int ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
+
+            for (int j = 0; j < Q; ++j) {
+                if (iq1 + j < ne01) {
+                    mp[j] = (float *)(mask + ((ir + j)%ne31) * nb31);
+                } else {
+                    mp[j] = nullptr;
+                }
+            }
+        }
+
+        for (int iic = C*warp_id; iic < ne11; iic += C*n_warps) {
+            // skip -INF blocks
+            // TODO: double-check this
+            {
+                float smc = -INFINITY;
+
+                for (int j = 0; j < Q; ++j) {
+                    const float mc = mp[j] ? mp[j][iic + lane_id] : -INFINITY;
+                    smc = warp_reduce_max(max(smc, mc));
+                }
+
+                if (smc == -INFINITY) {
+                    continue;
+                }
+            }
+
+            // Q*K^T
+            {
+                half16x16_a   mq{};
+                half16x16_b   mk{};
+                half16x16_acc mqk{};
+
+                for (int cc = 0; cc < C/16; ++cc) {
+                    nvcuda::wmma::fill_fragment(mqk, 0);
+                    // re fetch
+                    const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (int i = 0; i < D16; i++) {
+                        nvcuda::wmma::load_matrix_sync(mq, pq + i*16, T);
+                        nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
+                        nvcuda::wmma::mma_sync(mqk, mq, mk, mqk);
+                    }
+
+                    nvcuda::wmma::store_matrix_sync(ss + 16*cc, mqk, T, nvcuda::wmma::mem_col_major);
+                }
+            }
+
+            // online softmax
+            for (int64_t j = 0; j < Q; ++j) {
+                const int64_t p = lane_id;
+
+                const half s = ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]);
+
+                half m = M[j];
+
+                M[j] = warp_reduce_max(__hmax(M[j], s));
+
+                const half ms = __hisinf(m) == -1 ? 0.0 : hexp(m - M[j]);
+                const half vs = __hisinf(s) == -1 ? 0.0 : hexp(s - M[j]);
+
+                S[j] = S[j]*ms + warp_reduce_sum(vs);
+
+                ss[j*T + p] = vs;
+            }
+
+            __syncthreads();
+
+            // (Q*K^T)*V
+            {
+                half16x16_acc mqkv{};
+                half16x16_a   mqk{};
+                half16x16_b   mv{};
+
+                for (int64_t i = 0; i < D16; ++i) {
+                    nvcuda::wmma::fill_fragment(mqkv, 0);
+
+                    for (int cc = 0; cc < C/16; ++cc) {
+                        const half * pv = (const half *) ((const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                        nvcuda::wmma::load_matrix_sync(mqk, ss + cc*16, T);
+                        nvcuda::wmma::load_matrix_sync(mv, pv + i*16, nb21/sizeof(half));
+                        nvcuda::wmma::mma_sync(mqkv, mqk, mv, mqkv);
+                    }
+
+                    nvcuda::wmma::store_matrix_sync(ps + i*16, mqkv, T, nvcuda::wmma::mem_col_major);
+                }
+            }
+        }
+
+        for (int64_t j = 0; j < Q; ++j) {
+            if (lane_id == 0) {
+                ss[j*T + 0] = S[j];
+                ss[j*T + 1] = M[j];
+            }
+        }
+    }

     __syncthreads();

     // reduce the warps
+    // TODO: try parallel reduce
     if (warp_id == 0) {
-        for (int sg = 1; sg < nwraps; ++sg) {
-            half2 S0 = ss[                2*hiiw + 0];
-            half2 S1 = ss[sg*(R*D + 32) + 2*hiiw + 0];
-
-            half2 M0 = ss[                2*hiiw + 1];
-            half2 M1 = ss[sg*(R*D + 32) + 2*hiiw + 1];
-
-            M = __hmax2(M0, M1);
-
-            half2 ms0 = h2exp(M0 - M);
-            half2 ms1 = h2exp(M1 - M);
-
-            S = S0*ms0 + S1*ms1;
-
-            if (tiih == 0) {
-                ss[2*hiiw + 0] = S;
-                ss[2*hiiw + 1] = M;
-            }
-
-            for (int i = 0; i < D2/tph; ++i) {
-                ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih]*ms1;
-            }
-        }
-
-        for (int i = 0; i < D2/tph; ++i) {
-            ps2[hiiw*D2 + tph*i + tiih] = __h2div(ps2[hiiw*D2 + tph*i + tiih], S);
-        }
+        half S = 0.0;
+        half M = -INFINITY;
+
+        for (int64_t sg = 1; sg < n_warps; ++sg) {
+            for (int64_t j = 0; j < Q; ++j) {
+                const half S0 = ss[j*T +                0];
+                const half S1 = ss[j*T + sg*(D + 1*C) + 0];
+
+                const half M0 = ss[j*T +                1];
+                const half M1 = ss[j*T + sg*(D + 1*C) + 1];
+
+                M = __hmax(M0, M1);
+
+                const half ms0 = hexp(M0 - M);
+                const half ms1 = hexp(M1 - M);
+
+                S = S0*ms0 + S1*ms1;
+
+                if (lane_id == 0) {
+                    ss[j*T + 0] = S;
+                    ss[j*T + 1] = M;
+                }
+
+                for (int64_t i = 0; i < L2; ++i) {
+                    ps2[j*T2 + N4*i + lane_id] = ps2[j*T2 + N4*i + lane_id]*__half2half2(ms0) + ps2[j*T2 + sg*(D + 1*C)/4 + N4*i + lane_id]*__half2half2(ms1);
+                }
+            }
+        }
     }

     __syncthreads();

-    // dst indices
-    const int i1 = iq1;
-    const int i2 = iq2;
-    const int i3 = iq3;
-
     float2 * dst2 = (float2 *) kqv;

     if (warp_id == 0) {
-        for (int i = 0; i < D2/tph; ++i) {
-            dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]);
+        for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
+            half2 S = __half2half2(ss[j*T + 0]);
+
+            for (int i = 0; i < L2; ++i) {
+                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + N4*i + lane_id] = __half22float2(ps2[j*T2 + N4*i + lane_id]/S);
+            }
         }
     }
-#endif
 }
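Note: the per-row state the rewritten kernel carries is the classic streaming-softmax pair: a running maximum M[j] and a running denominator S[j], where every new block of C scores rescales the old accumulator by exp(M_old - M_new) before being added (the same ms factor also rescales the accumulated V rows, and the final store divides by S). A scalar float sketch of that recurrence (illustrative, not ggml code):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Streaming ("online") softmax statistics: process scores one chunk at a time
    // while keeping a running max m and running denominator s, rescaling past
    // work whenever the maximum grows. flash_attn_ext_f16 does the same per
    // query row, in half precision and warp-wide.
    int main() {
        const std::vector<float> scores = {0.5f, 2.0f, -1.0f, 3.0f, 0.0f, 1.5f};

        float m = -INFINITY; // running maximum
        float s = 0.0f;      // running sum of exp(score - m)

        for (float x : scores) {
            const float m_new = std::max(m, x);
            const float ms = std::isinf(m) ? 0.0f : std::exp(m - m_new); // rescale old sum
            const float vs = std::exp(x - m_new);                        // weight of new score
            s = s*ms + vs;
            m = m_new;
        }

        // m and s match the max and softmax denominator a single pass would give.
        printf("max = %f, denom = %f\n", m, s);
        return 0;
    }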
@@ -10300,7 +10391,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     memcpy(&scale, KQV->op_params, sizeof(float));

     const int nwarps = Q->ne[1] < 4 ? 4 : 2;
-    const int nqpb = 2; // queries per block
+    const int nqpb = 16; // queries per block
     const int ncpw = 32; // cache values per warp (does not work for other values)

     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
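Note: with nqpb raised from 2 to 16 each block now owns a 16-query tile, matching the template parameter Q and the 16-wide WMMA fragments. The shared-memory requirement per block follows from the kernel's T = D + n_warps*(D + 1*C) definition; the host-side sketch below recomputes it for one configuration (the shmem formula is inferred from that definition, since the actual shmem expression lies outside this hunk, and block_dim is assumed to be WARP_SIZE x nwarps):

    #include <cstdio>

    int main() {
        const int D      = 128; // head size, Q->ne[0]
        const int n_q    = 512; // number of queries, Q->ne[1]
        const int nwarps = 2;   // Q->ne[1] < 4 ? 4 : 2
        const int nqpb   = 16;  // queries per block (template parameter Q)
        const int ncpw   = 32;  // cache values per warp (template parameter C)

        const int    T     = D + nwarps*(D + 1*ncpw); // halfs per query row
        const size_t shmem = (size_t) nqpb*T*2;       // bytes, half = 2 bytes

        const int blocks_x = (n_q + nqpb - 1)/nqpb;   // grid.x, one block per 16 queries
        printf("grid.x = %d, dynamic shared memory = %zu bytes/block\n", blocks_x, shmem);
        return 0;
    }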
@@ -10311,7 +10402,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     switch (Q->ne[0])
     {
        case 64:
-            flash_attn_ext_f16<64, 8, 32>
+            flash_attn_ext_f16<64, 16, 32>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                        (const char *) src0_extra->data_device[g_main_device], // Query
                        (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10328,7 +10419,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                );
            break;
        case 80:
-            flash_attn_ext_f16<80, 8, 32>
+            flash_attn_ext_f16<80, 16, 32>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                        (const char *) src0_extra->data_device[g_main_device], // Query
                        (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10345,7 +10436,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                );
            break;
        case 128:
-            flash_attn_ext_f16<128, 8, 32>
+            flash_attn_ext_f16<128, 16, 32>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
                        (const char *) src0_extra->data_device[g_main_device], // Query
                        (const char *) src1_extra->data_device[g_main_device], // Key