integrate tensor cores
This commit is contained in:
parent
6e7cb0eeaf
commit
0a481fe1a9
1 changed files with 177 additions and 86 deletions
263
ggml-cuda.cu
263
ggml-cuda.cu
|
@ -104,6 +104,7 @@
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cublas_v2.h>
|
#include <cublas_v2.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
|
#include <mma.h>
|
||||||
|
|
||||||
#if CUDART_VERSION < 11020
|
#if CUDART_VERSION < 11020
|
||||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||||
|
@ -621,6 +622,14 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ __half warp_reduce_sum(__half x) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
|
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
|
@ -642,6 +651,19 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ half warp_reduce_max(half x) {
|
||||||
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
|
x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
#else
|
||||||
|
(void) x;
|
||||||
|
bad_arch();
|
||||||
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
||||||
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
|
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
|
||||||
return b;
|
return b;
|
||||||
GGML_UNUSED(a);
|
GGML_UNUSED(a);
|
||||||
|
@ -6112,6 +6134,10 @@ static __global__ void flash_attn_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, __half, nvcuda::wmma::col_major> half16x16_a;
|
||||||
|
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, __half, nvcuda::wmma::col_major> half16x16_b;
|
||||||
|
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, __half> half16x16_acc;
|
||||||
|
|
||||||
// based on metal version
|
// based on metal version
|
||||||
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per blocks
|
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per blocks
|
||||||
static __global__ void flash_attn_ext_f16(
|
static __global__ void flash_attn_ext_f16(
|
||||||
|
@ -6152,17 +6178,17 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int D2 = D/2;
|
const int D2 = D/2;
|
||||||
const int N4 = WARP_SIZE;
|
const int N4 = WARP_SIZE;
|
||||||
const int L2 = (D2 + N4 - 1)/N4;
|
const int L2 = (D2 + N4 - 1)/N4;
|
||||||
const int D8 = D/8;
|
const int D16 = D/16;
|
||||||
|
|
||||||
const int T = D + n_warps*(D + 1*C); // shared memory size per query in half
|
const int T = D + n_warps*(D + 1*C); // shared memory size per query in half
|
||||||
const int T2 = T/2; // shared memory size per query in half2
|
const int T2 = T/2; // shared memory size per query in half2
|
||||||
|
|
||||||
const half2 scale_h = __half2half2(__float2half(scale));
|
const half scale_h = __float2half(scale);
|
||||||
|
|
||||||
extern __shared__ char data_flash_attn_shmem[];
|
extern __shared__ char data_flash_attn_shmem[];
|
||||||
|
// pq
|
||||||
half * pq = (half *) (data_flash_attn_shmem + 0*D);
|
half * pq = (half *) (data_flash_attn_shmem + 0*D);
|
||||||
half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D);
|
half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D);
|
||||||
half * ps = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
|
half * ps = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
|
||||||
half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
|
half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
|
||||||
half * ss = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D);
|
half * ss = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D);
|
||||||
|
@ -6191,120 +6217,185 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
|
{
|
||||||
|
half S[Q] = { 0.0 };
|
||||||
|
half M[Q] = { -INFINITY };
|
||||||
|
|
||||||
half S[8] = { 0.0 };
|
// assume K and V are same shape
|
||||||
#if 0
|
const int ne22 = ne12;
|
||||||
half2 M = make_half2(-INFINITY, -INFINITY);
|
const int ne23 = ne13;
|
||||||
|
|
||||||
const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
|
const int nb21 = nb11;
|
||||||
|
const int nb22 = nb12;
|
||||||
|
const int nb23 = nb13;
|
||||||
|
|
||||||
for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
|
// broadcast
|
||||||
const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0);
|
const int rk2 = ne02/ne12;
|
||||||
if (__hisinf(mv.x) == -1) { // mv == -INFINITY
|
const int rk3 = ne03/ne13;
|
||||||
continue;
|
|
||||||
|
const int rv2 = ne02/ne22;
|
||||||
|
const int rv3 = ne03/ne23;
|
||||||
|
|
||||||
|
// k indices
|
||||||
|
const int ik2 = iq2 / rk2;
|
||||||
|
const int ik3 = iq3 / rk3;
|
||||||
|
|
||||||
|
// v indices
|
||||||
|
const int iv2 = iq2 / rv2;
|
||||||
|
const int iv3 = iq3 / rv3;
|
||||||
|
|
||||||
|
// TODO: this can be improved
|
||||||
|
float * mp[Q];
|
||||||
|
|
||||||
|
{
|
||||||
|
const int ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
|
||||||
|
|
||||||
|
for (int j = 0; j < Q; ++j) {
|
||||||
|
if (iq1 + j < ne01) {
|
||||||
|
mp[j] = (float *)(mask + ((ir + j)%ne31) * nb31);
|
||||||
|
} else {
|
||||||
|
mp[j] = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
half2 * pk2 = (half2 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
|
for (int iic = C*warp_id; iic < ne11; iic += C*n_warps) {
|
||||||
half2 * pv2 = (half2 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K
|
// skip -INF blocks
|
||||||
|
// TODO: double-check this
|
||||||
|
{
|
||||||
|
float smc = -INFINITY;
|
||||||
|
|
||||||
half2 s2 = make_half2(0.0, 0.0);
|
for (int j = 0; j < Q; ++j) {
|
||||||
|
const float mc = mp[j] ? mp[j][iic + lane_id] : -INFINITY;
|
||||||
|
smc = warp_reduce_max(max(smc, mc));
|
||||||
|
}
|
||||||
|
|
||||||
#pragma unroll
|
if (smc == -INFINITY) {
|
||||||
for (int i = 0; i < D2/tph; ++i) {
|
continue;
|
||||||
s2 = pq2[hiiw*D2 + tph*i + tiih] * pk2[tph*i + tiih] + s2;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
ss[hiiw*tph + tiih] = __half2half2(s2.x + s2.y);
|
|
||||||
|
|
||||||
__syncthreads();
|
|
||||||
|
|
||||||
if (tiih == 0) {
|
|
||||||
half2 s = make_half2(0.0, 0.0);
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < tph; ++i) {
|
|
||||||
s += ss[hiiw*tph + i];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s = s*scale_h + mv; // s*scale + mv
|
// Q*K^T
|
||||||
|
{
|
||||||
|
half16x16_a mq{};
|
||||||
|
half16x16_b mk{};
|
||||||
|
half16x16_acc mqk{};
|
||||||
|
|
||||||
half2 m = M;
|
for (int cc = 0; cc < C/16; ++cc) {
|
||||||
|
nvcuda::wmma::fill_fragment(mqk, 0); // re fetch
|
||||||
|
|
||||||
M = __hmax2(M, s);
|
const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
|
||||||
half2 ms = h2exp(m - M);
|
for(int i = 0; i < D16;i ++) {
|
||||||
half2 vs = h2exp(s - M);
|
nvcuda::wmma::load_matrix_sync(mq, pq + i*16, T);
|
||||||
|
nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
|
||||||
|
nvcuda::wmma::mma_sync(mqk, mq, mk, mqk);
|
||||||
|
}
|
||||||
|
|
||||||
S = S*ms + vs;
|
nvcuda::wmma::store_matrix_sync(ss + 16*cc, mqk, T, nvcuda::wmma::mem_col_major);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ss[2*hiiw + 0] = ms;
|
// online softmax
|
||||||
ss[2*hiiw + 1] = vs;
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
|
const int64_t p = lane_id;
|
||||||
|
|
||||||
|
const half s = ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]);
|
||||||
|
|
||||||
|
half m = M[j];
|
||||||
|
|
||||||
|
M[j] = warp_reduce_max(__hmax(M[j], s));
|
||||||
|
|
||||||
|
const half ms = __hisinf(m) == -1 ? 0.0 : hexp(m - M[j]);
|
||||||
|
const half vs = __hisinf(s) == -1 ? 0.0 : hexp(s - M[j]);
|
||||||
|
|
||||||
|
S[j] = S[j]*ms + warp_reduce_sum(vs);
|
||||||
|
|
||||||
|
ss[j*T + p] = vs;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// (Q*K^T)*V
|
||||||
|
{
|
||||||
|
half16x16_acc mqkv{};
|
||||||
|
half16x16_a mqk{};
|
||||||
|
half16x16_b mv{};
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < D16; ++i) {
|
||||||
|
nvcuda::wmma::fill_fragment(mqkv, 0);
|
||||||
|
|
||||||
|
for (int cc = 0; cc < C/16; ++cc) {
|
||||||
|
const half * pv = (const half *) ((const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
|
||||||
|
|
||||||
|
nvcuda::wmma::load_matrix_sync(mqk, ss + cc*16, T);
|
||||||
|
nvcuda::wmma::load_matrix_sync(mv, pv + i*16, nb21/sizeof(half));
|
||||||
|
|
||||||
|
nvcuda::wmma::mma_sync(mqkv, mqk, mv, mqkv);
|
||||||
|
}
|
||||||
|
|
||||||
|
nvcuda::wmma::store_matrix_sync(ps + i*16, mqkv, T, nvcuda::wmma::mem_col_major);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
|
if (lane_id == 0) {
|
||||||
half2 ms = ss[2*hiiw + 0];
|
ss[j*T + 0] = S[j];
|
||||||
half2 vs = ss[2*hiiw + 1];
|
ss[j*T + 1] = M[j];
|
||||||
|
}
|
||||||
#pragma unroll
|
|
||||||
for (int i = 0; i < D2/tph; ++i) {
|
|
||||||
ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms + pv2[tph*i + tiih]*vs;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tiih == 0) {
|
|
||||||
ss[2*hiiw + 0] = S;
|
|
||||||
ss[2*hiiw + 1] = M;
|
|
||||||
}
|
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// reduce the warps
|
// reduce the warps
|
||||||
|
// TODO: try parallel reduce
|
||||||
if (warp_id == 0) {
|
if (warp_id == 0) {
|
||||||
for (int sg = 1; sg < nwraps; ++sg) {
|
half S = 0.0;
|
||||||
half2 S0 = ss[ 2*hiiw + 0];
|
half M = -INFINITY;
|
||||||
half2 S1 = ss[sg*(R*D + 32) + 2*hiiw + 0];
|
|
||||||
|
|
||||||
half2 M0 = ss[ 2*hiiw + 1];
|
for (int64_t sg = 1; sg < n_warps; ++sg) {
|
||||||
half2 M1 = ss[sg*(R*D + 32) + 2*hiiw + 1];
|
for (int64_t j = 0; j < Q; ++j) {
|
||||||
|
const half S0 = ss[j*T + 0];
|
||||||
|
const half S1 = ss[j*T + sg*(D + 1*C) + 0];
|
||||||
|
|
||||||
M = __hmax2(M0, M1);
|
const half M0 = ss[j*T + 1];
|
||||||
|
const half M1 = ss[j*T + sg*(D + 1*C) + 1];
|
||||||
|
|
||||||
half2 ms0 = h2exp(M0 - M);
|
M = __hmax(M0, M1);
|
||||||
half2 ms1 = h2exp(M1 - M);
|
|
||||||
|
|
||||||
S = S0*ms0 + S1*ms1;
|
const half ms0 = hexp(M0 - M);
|
||||||
|
const half ms1 = hexp(M1 - M);
|
||||||
|
|
||||||
if (tiih == 0) {
|
S = S0*ms0 + S1*ms1;
|
||||||
ss[2*hiiw + 0] = S;
|
|
||||||
ss[2*hiiw + 1] = M;
|
if (lane_id == 0) {
|
||||||
|
ss[j*T + 0] = S;
|
||||||
|
ss[j*T + 1] = M;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t i = 0; i < L2; ++i) {
|
||||||
|
ps2[j*T2 + N4*i + lane_id] = ps2[j*T2 + N4*i + lane_id]*__half2half2(ms0) + ps2[j*T2 + sg*(D + 1*C)/4 + N4*i + lane_id]*__half2half2(ms1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < D2/tph; ++i) {
|
|
||||||
ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih]*ms1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < D2/tph; ++i) {
|
|
||||||
ps2[hiiw*D2 + tph*i + tiih] = __h2div(ps2[hiiw*D2 + tph*i + tiih], S);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// dst indices
|
|
||||||
const int i1 = iq1;
|
|
||||||
const int i2 = iq2;
|
|
||||||
const int i3 = iq3;
|
|
||||||
|
|
||||||
float2 * dst2 = (float2 *) kqv;
|
float2 * dst2 = (float2 *) kqv;
|
||||||
|
|
||||||
if (warp_id == 0) {
|
if (warp_id == 0) {
|
||||||
for (int i = 0; i < D2/tph; ++i) {
|
for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||||
dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]);
|
half2 S = __half2half2(ss[j*T + 0]);
|
||||||
|
|
||||||
|
for (int i = 0; i < L2; ++i) {
|
||||||
|
dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + N4*i + lane_id] = __half22float2(ps2[j*T2 + N4*i + lane_id]/S);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -10300,7 +10391,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
||||||
memcpy(&scale, KQV->op_params, sizeof(float));
|
memcpy(&scale, KQV->op_params, sizeof(float));
|
||||||
|
|
||||||
const int nwarps = Q->ne[1] < 4 ? 4 : 2;
|
const int nwarps = Q->ne[1] < 4 ? 4 : 2;
|
||||||
const int nqpb = 2; // queries per block
|
const int nqpb = 16; // queries per block
|
||||||
const int ncpw = 32; // cache values per warp (does not work for other values)
|
const int ncpw = 32; // cache values per warp (does not work for other values)
|
||||||
|
|
||||||
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
|
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
|
||||||
|
@ -10311,7 +10402,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
||||||
switch (Q->ne[0])
|
switch (Q->ne[0])
|
||||||
{
|
{
|
||||||
case 64:
|
case 64:
|
||||||
flash_attn_ext_f16<64, 8, 32>
|
flash_attn_ext_f16<64, 16, 32>
|
||||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||||
|
@ -10328,7 +10419,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
||||||
);
|
);
|
||||||
break;
|
break;
|
||||||
case 80:
|
case 80:
|
||||||
flash_attn_ext_f16<80, 8, 32>
|
flash_attn_ext_f16<80, 16, 32>
|
||||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||||
|
@ -10345,7 +10436,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
||||||
);
|
);
|
||||||
break;
|
break;
|
||||||
case 128:
|
case 128:
|
||||||
flash_attn_ext_f16<128, 8, 32>
|
flash_attn_ext_f16<128, 16, 32>
|
||||||
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
(const char *) src0_extra->data_device[g_main_device], // Query
|
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||||
(const char *) src1_extra->data_device[g_main_device], // Key
|
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue