update implementation

FSSRepo 2024-01-25 11:04:51 -05:00
parent 78da3387a8
commit 6e7cb0eeaf


@@ -6113,7 +6113,7 @@ static __global__ void flash_attn_f32(
}
// based on metal version
template<int D, int R> // D head size, R rows per block
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
static __global__ void flash_attn_ext_f16(
const char* __restrict__ q,
const char* __restrict__ k,
@@ -6141,62 +6141,64 @@ static __global__ void flash_attn_ext_f16(
int ne1,
int ne2,
int ne3) {
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = threadIdx.y;
const int lane_id = threadIdx.x;
const int nwraps = blockDim.y; // number of warps
const int tph = WARP_SIZE / R; // threads per head
const int n_warps = blockDim.y; // number of warps
const int iq3 = blockIdx.z;
const int iq2 = blockIdx.y * R + lane_id / tph;
const int iq1 = blockIdx.x;
const int iq2 = blockIdx.y;
const int iq1 = blockIdx.x * Q;
if(iq2 >= ne02) {
return;
}
const int D2 = D/2;
const int N4 = WARP_SIZE;
const int L2 = (D2 + N4 - 1)/N4;
const int D8 = D/8;
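// illustrative values, assuming D = 64 as in the host launch code below: D2 = 32, N4 = 32, L2 = 1, D8 = 8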
// broadcast
const int rk2 = ne02 / ne12;
const int rk3 = ne03 / ne13;
// assume the same K and V shape
// const int rv2 = ne02 / ne12;
// const int rv3 = ne03 / ne13;
// kv indices
const int ik2 = iq2 / rk2;
const int ik3 = iq3 / rk3;
// const int iv2 = iq2 / rv2;
// const int iv3 = iq3 / rv3;
const int T = D + n_warps*(D + 1*C); // shared memory size per query in half
const int T2 = T/2; // shared memory size per query in half2
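// illustrative arithmetic, assuming D = 64, C = 32 and n_warps = 2 as in the host launch code below:
// T = 64 + 2*(64 + 32) = 256 half per query row, T2 = 128 half2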
const half2 scale_h = __half2half2(__float2half(scale));
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
extern __shared__ char data_flash_attn_shmem[];
half2* pq2 = (half2*)data_flash_attn_shmem;
half2* ps2 = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 1*R*D);
half2* ss = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 2*R*D);
half * pq = (half *) (data_flash_attn_shmem + 0*D);
half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D);
half * ps = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
half * ss = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D);
const int tiih = lane_id % tph; // thread index in head
const int hiiw = lane_id / tph; // head index in warp
const int D2 = D / 2; // number of half2 to store head_dim row
// load R heads from Q to shared memory
for (int i = 0; i < D2/tph; ++i) {
if (warp_id == 0) {
pq2[hiiw*D2 + tph*i + tiih] = ((half2*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
for (int i = 0; i < L2; ++i) {
// load heads from Q to shared memory
for (int j = warp_id; j < Q; j += n_warps) {
if (iq1 + j < ne01) {
pq2[j*T2 + N4*i + lane_id] = ((half2*) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + lane_id];
} else {
pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0);
}
}
ps2[hiiw*D2 + tph*i + tiih] = make_half2(0.0, 0.0);
// zero out shared memory
for (int j = 0; j < Q; ++j) {
ps2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0);
}
}
if (lane_id < C) {
for (int j = 0; j < Q; ++j) {
ss[j*T + 0 + lane_id] = 0.0;
}
}
__syncthreads();
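// shared memory layout per query row j (sketch, offsets in half units): [0, D) holds the Q vector,
// then each warp owns a contiguous block of D accumulator slots (ps) followed by C score slots (ss),
// giving T = D + n_warps*(D + C) half per row in total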
half2 S = make_half2(0.0, 0.0);
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
half S[8] = { 0.0 };
#if 0
half2 M = make_half2(-INFINITY, -INFINITY);
const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0);
if (__hisinf(mv.x) == -1) { // mv == -INFINITY
@@ -6302,6 +6304,7 @@ static __global__ void flash_attn_ext_f16(
dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]);
}
}
#endif
}
@@ -10296,18 +10299,19 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
const int nwarps = 32;
const int nhpw = 2; // heads per warp
const int nwarps = Q->ne[1] < 4 ? 4 : 2;
const int nqpb = 2; // queries per block
const int ncpw = 32; // cache values per warp (does not work for other values)
dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1) / nhpw, Q->ne[3]);
dim3 block_dim(32 * nwarps, 1, 1);
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
dim3 block_dim(32, nwarps, 1);
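// launch geometry: the grid covers (Q->ne[1] + nqpb - 1)/nqpb query tiles x Q->ne[2] heads x Q->ne[3] batch entries,
// and each block runs nwarps warps of 32 threads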
int shmem = (nhpw*Q->ne[0]*2 + nwarps*(nhpw*Q->ne[0] + 32)) * (sizeof(float)/2);
int shmem = nqpb*(Q->ne[0] + nwarps*(Q->ne[0] + 1*ncpw))*(sizeof(float)/2);
printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]);
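// illustrative shared memory size, assuming D = Q->ne[0] = 64, nqpb = 2, ncpw = 32 and nwarps = 2:
// shmem = 2*(64 + 2*(64 + 32))*(sizeof(float)/2) = 1024 bytes per block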
switch (Q->ne[0])
{
case 64:
flash_attn_ext_f16<64, 2>
flash_attn_ext_f16<64, 8, 32>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
@@ -10324,7 +10328,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 80:
flash_attn_ext_f16<80, 2>
flash_attn_ext_f16<80, 8, 32>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key
@ -10341,7 +10345,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
);
break;
case 128:
flash_attn_ext_f16<128, 2>
flash_attn_ext_f16<128, 8, 32>
<<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key