cuda: port metal version flash_attn_ext

FSSRepo 2024-01-23 16:42:53 -05:00
parent a689b02ad3
commit 6374bc5779


@@ -937,6 +937,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
@@ -6106,6 +6107,211 @@ static __global__ void flash_attn_f32(
}
}
struct __align__(8) half4 {
half x;
half y;
half z;
half w;
};
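// element-wise helpers for the custom half4 type used by the kernel below
// (a minimal sketch, assuming native half arithmetic is available on the target architecture)
static __device__ __forceinline__ half4 operator*(const half4 & a, const half4 & b) {
    return { a.x*b.x, a.y*b.y, a.z*b.z, a.w*b.w };
}
static __device__ __forceinline__ half4 operator*(const half4 & a, const half b) {
    return { a.x*b, a.y*b, a.z*b, a.w*b };
}
static __device__ __forceinline__ half4 operator+(const half4 & a, const half4 & b) {
    return { a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w };
}
static __device__ __forceinline__ half4 operator/(const half4 & a, const half b) {
    return { a.x/b, a.y/b, a.z/b, a.w/b };
}
static __device__ __forceinline__ half4 & operator+=(half4 & a, const half4 & b) {
    a = a + b;
    return a;
}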
// based on metal version
template<int D, int R> // D = head size, R = heads processed per block
static __global__ void flash_attn_ext_f16(
const char* __restrict__ q,
const char* __restrict__ k,
const char* __restrict__ v,
const char* __restrict__ mask,
float* __restrict__ kqv,
float scale,
int ne00,
int ne01,
int ne02,
int ne03,
int ne10,
int ne11,
int ne12,
int ne13,
int ne31,
int nb31,
int nb01,
int nb02,
int nb03,
int nb11,
int nb12,
int nb13,
int ne0,
int ne1,
int ne2,
int ne3) {
const int warp_id = threadIdx.y; // blocks are launched as (WARP_SIZE, nwarps, 1)
const int lane_id = threadIdx.x;
const int nwarps = blockDim.y; // number of warps per block
const int tph = WARP_SIZE / R; // threads per head
const int iq3 = blockIdx.z;
const int iq2 = blockIdx.y * R + lane_id / tph;
const int iq1 = blockIdx.x;
if(iq2 >= ne02) {
return;
}
// broadcast
const int rk2 = ne02 / ne12;
const int rk3 = ne03 / ne13;
// assume the same K and V shape
const int rv2 = ne02 / ne12;
const int rv3 = ne03 / ne13;
// kv indices
const int ik2 = iq2 / rk2;
const int ik3 = iq3 / rk3;
const int iv2 = iq2 / rv2;
const int iv3 = iq3 / rv3;
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
const float * mp = mask ? (const float *) (mask + (ir % ne31)*nb31) : nullptr;
extern __shared__ char shmem__[];
// the shared buffer is addressed in half elements: R*D for the shared Q rows, then
// per warp R*D for the output accumulator (ps) and 32 for the score scratchpad (ss)
half * sh = (half *) shmem__;
half4* pq4 = (half4*)(sh);
half4* ps4 = (half4*)(sh + warp_id * (R * D + 32) + 1*R*D);
half* ss = (half *)(sh + warp_id * (R * D + 32) + 2*R*D);
const int tiih = lane_id % tph; // thread index in head
const int hiisg = lane_id / tph; // head index in warp
const int D4 = D/4;
// load R heads from Q to shared memory
for (int64_t i = 0; i < D4/tph; ++i) {
if (warp_id == 0) {
pq4[hiisg*D4 + tph*i + tiih] = ((const half4 *)((const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
}
ps4[hiisg*D4 + tph*i + tiih] = { __float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f) };
}
__syncthreads();
half S = __float2half(0.0f); // running softmax denominator
half M = __float2half(-INFINITY); // running maximum of the scaled scores
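// online softmax over the K/V columns: each warp walks the columns ic = warp_id, warp_id + nwarps, ...
// and, per column, computes the Q·K score, rescales the running (S, M) statistics and accumulates
// the weighted V row into its ps4 slice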
for (int64_t ic = warp_id; ic < ne11; ic += nwarps) {
const half mv = mp ? __float2half(mp[ic]) : __float2half(0.0f);
if (mv == -INFINITY) {
continue;
}
const half4 * pk4 = (const half4 *) ((const char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
const half4 * pv4 = (const half4 *) ((const char *) v + (ic*nb11 + iv2*nb12 + iv3*nb13)); // assumes V has the same shape as K
half4 s4 = { __float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f) }; // per-thread partial Q·K dot product
#pragma unroll
for (int i = 0; i < D4/tph; ++i) {
s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
}
ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);
__syncwarp(); // ss/ps4 are per-warp; a block-wide barrier here could deadlock because warps skip masked columns
if (tiih == 0) {
half s = __float2half(0.0f);
#pragma unroll
for (int i = 0; i < tph; ++i) {
s += ss[hiisg*tph + i];
}
s = s*scale + mv;
const half m = M;
M = s > M ? s : M; // max(M, s)
const half ms = exp(m - M);
const half vs = exp(s - M);
S = S*ms + vs;
ss[2*hiisg + 0] = ms;
ss[2*hiisg + 1] = vs;
}
__syncwarp(); // make ms/vs visible to all lanes of the warp
const half ms = ss[2*hiisg + 0];
const half vs = ss[2*hiisg + 1];
#pragma unroll
for (int i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs;
}
}
if (tiih == 0) {
ss[2*hiisg + 0] = S;
ss[2*hiisg + 1] = M;
}
__syncthreads();
// reduce across the warps: warp 0 merges the (S, M) statistics and partial outputs of the other warps
if (warp_id == 0) {
for (int sg = 1; sg < nwarps; ++sg) {
const half S0 = ss[ 2*hiisg + 0];
const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];
const half M0 = ss[ 2*hiisg + 1];
const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
M = M0 > M1 ? M0 : M1; // max(M0, M1)
const half ms0 = exp(M0 - M);
const half ms1 = exp(M1 - M);
S = S0*ms0 + S1*ms1;
if (tiih == 0) {
ss[2*hiisg + 0] = S;
ss[2*hiisg + 1] = M;
}
for (int i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1;
}
}
for (int i = 0; i < D4/tph; ++i) {
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S;
}
}
__syncthreads();
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
float4 * dst4 = (float4 *) kqv;
if (warp_id == 0) {
for (int i = 0; i < D4/tph; ++i) {
const half4 src_ = ps4[hiisg*D4 + tph*i + tiih];
float4 dst_;
dst_.x = __half2float(src_.x);
dst_.y = __half2float(src_.y);
dst_.z = __half2float(src_.z);
dst_.w = __half2float(src_.w);
dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih] = dst_;
}
}
}
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@@ -10071,6 +10277,98 @@ inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, c
KQ_scale, d_head, sequence_length, num_heads, main_stream);
}
inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
GGML_ASSERT(Q->type == GGML_TYPE_F16);
GGML_ASSERT(K->type == GGML_TYPE_F16);
GGML_ASSERT(V->type == GGML_TYPE_F16);
GGML_ASSERT(mask->type == GGML_TYPE_F32);
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
ggml_cuda_set_device(g_main_device);
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
ggml_tensor_extra_gpu * src3_extra = (ggml_tensor_extra_gpu *) mask->extra;
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra;
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
const int nwarps = 32;
const int nhpw = 2; // heads per warp
dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1)/(nhpw), Q->ne[3]);
dim3 block_dim(32, nwarps, 1);
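// shared memory in bytes: nhpw*D halves for the shared Q rows plus, per warp,
// nhpw*D halves for the output accumulator and 32 halves of score scratchpad (nhpw matches the kernel's R)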
int shmem = (nhpw*Q->ne[0] + nwarps*(nhpw*Q->ne[0] + 32))*(sizeof(float)/2);
switch (Q->ne[0])
{
case 64:
flash_attn_ext_f16<64, 2>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
(const char *) src2_extra->data_device[g_main_device], // Value
(const char *) src3_extra->data_device[g_main_device], // Mask
(float *) dst_extra->data_device[g_main_device], // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask->ne[1], mask->nb[1],
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
case 80:
flash_attn_ext_f16<80, 2>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
(const char *) src2_extra->data_device[g_main_device], // Value
(const char *) src3_extra->data_device[g_main_device], // Mask
(float *) dst_extra->data_device[g_main_device], // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask->ne[1], mask->nb[1],
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
case 128:
flash_attn_ext_f16<128, 2>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
(const char *) src2_extra->data_device[g_main_device], // Value
(const char *) src3_extra->data_device[g_main_device], // Mask
(float *) dst_extra->data_device[g_main_device], // dst
scale,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask->ne[1], mask->nb[1],
Q->nb[1], Q->nb[2], Q->nb[3],
K->nb[1], K->nb[2], K->nb[3],
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
);
break;
default:
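// unsupported head size: no kernel is launched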
break;
}
}
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
}
@@ -10341,6 +10639,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
break;
case GGML_OP_FLASH_ATTN:
break;
case GGML_OP_FLASH_ATTN_EXT:
break;
default:
return false;
}
@@ -10357,6 +10657,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
}
if(tensor->op == GGML_OP_FLASH_ATTN) {
ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor);
} else if(tensor->op == GGML_OP_FLASH_ATTN_EXT) {
ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
} else {
func(tensor->src[0], tensor->src[1], tensor);
}
@@ -11175,6 +11477,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
return true;
default:
return false;