update implementation

This commit is contained in:
FSSRepo 2024-01-25 11:04:51 -05:00
parent 78da3387a8
commit 6e7cb0eeaf

View file

@ -6113,7 +6113,7 @@ static __global__ void flash_attn_f32(
} }
// based on metal version // based on metal version
template<int D, int R> // D head size, R rows per block template<int D, int Q, int C> // D head size, Q queries per block, C cache items per blocks
static __global__ void flash_attn_ext_f16( static __global__ void flash_attn_ext_f16(
const char* __restrict__ q, const char* __restrict__ q,
const char* __restrict__ k, const char* __restrict__ k,
@ -6141,62 +6141,64 @@ static __global__ void flash_attn_ext_f16(
int ne1, int ne1,
int ne2, int ne2,
int ne3) { int ne3) {
int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = threadIdx.y;
int lane_id = threadIdx.x % WARP_SIZE; const int lane_id = threadIdx.x;
const int nwraps = blockDim.y; // number of warps const int n_warps = blockDim.y; // number of warps
const int tph = WARP_SIZE / R; // threads per head
const int iq3 = blockIdx.z; const int iq3 = blockIdx.z;
const int iq2 = blockIdx.y * R + lane_id / tph; const int iq2 = blockIdx.y;
const int iq1 = blockIdx.x; const int iq1 = blockIdx.x * Q;
if(iq2 >= ne02) { const int D2 = D/2;
return; const int N4 = WARP_SIZE;
} const int L2 = (D2 + N4 - 1)/N4;
const int D8 = D/8;
// broadcast const int T = D + n_warps*(D + 1*C); // shared memory size per query in half
const int rk2 = ne02 / ne12; const int T2 = T/2; // shared memory size per query in half2
const int rk3 = ne03 / ne13;
// assume the same K and V shape
// const int rv2 = ne02 / ne12;
// const int rv3 = ne03 / ne13;
// kv indices
const int ik2 = iq2 / rk2;
const int ik3 = iq3 / rk3;
// const int iv2 = iq2 / rv2;
// const int iv3 = iq3 / rv3;
const half2 scale_h = __half2half2(__float2half(scale)); const half2 scale_h = __half2half2(__float2half(scale));
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
extern __shared__ char data_flash_attn_shmem[]; extern __shared__ char data_flash_attn_shmem[];
half2* pq2 = (half2*)data_flash_attn_shmem; half * pq = (half *) (data_flash_attn_shmem + 0*D);
half2* ps2 = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 1*R*D); half2 * pq2 = (half2 *) (data_flash_attn_shmem + 0*D);
half2* ss = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 2*R*D); half * ps = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
half * ss = (half *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D);
const int tiih = lane_id % tph; // thread index in head for (int i = 0; i < L2; ++i) {
const int hiiw = lane_id / tph; // head index in warp // load heads from Q to shared memory
for (int j = warp_id; j < Q; j += n_warps) {
const int D2 = D / 2; // number of half2 to store head_dim row if (iq1 + j < ne01) {
pq2[j*T2 + N4*i + lane_id] = ((half2*) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + lane_id];
// load R heads from Q to shared memory } else {
for (int i = 0; i < D2/tph; ++i) { pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0);
if (warp_id == 0) { }
pq2[hiiw*D2 + tph*i + tiih] = ((half2*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
} }
ps2[hiiw*D2 + tph*i + tiih] = make_half2(0.0, 0.0); // zero out shared memory
for (int j = 0; j < Q; ++j) {
ps2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0);
} }
}
if (lane_id < C) {
for (int j = 0; j < Q; ++j) {
ss[j*T + 0 + lane_id] = 0.0;
}
}
__syncthreads(); __syncthreads();
half2 S = make_half2(0.0, 0.0); const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
half S[8] = { 0.0 };
#if 0
half2 M = make_half2(-INFINITY, -INFINITY); half2 M = make_half2(-INFINITY, -INFINITY);
const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
for (int64_t ic = warp_id; ic < ne11; ic += nwraps) { for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0); const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0);
if (__hisinf(mv.x) == -1) { // mv == -INFINITY if (__hisinf(mv.x) == -1) { // mv == -INFINITY
@ -6302,6 +6304,7 @@ static __global__ void flash_attn_ext_f16(
dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]); dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]);
} }
} }
#endif
} }
@ -10296,18 +10299,19 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
float scale; float scale;
memcpy(&scale, KQV->op_params, sizeof(float)); memcpy(&scale, KQV->op_params, sizeof(float));
const int nwarps = 32; const int nwarps = Q->ne[1] < 4 ? 4 : 2;
const int nhpw = 2; // heads per warp const int nqpb = 2; // queries per block
const int ncpw = 32; // cache values per warp (does not work for other values)
dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1) / nhpw, Q->ne[3]); dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
dim3 block_dim(32 * nwarps, 1, 1); dim3 block_dim(32, nwarps, 1);
int shmem = (nhpw*Q->ne[0]*2 + nwarps*(nhpw*Q->ne[0] + 32)) * (sizeof(float)/2); int shmem = nqpb*(Q->ne[0] + nwarps*(Q->ne[0] + 1*ncpw))*(sizeof(float)/2);
printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]); printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]);
switch (Q->ne[0]) switch (Q->ne[0])
{ {
case 64: case 64:
flash_attn_ext_f16<64, 2> flash_attn_ext_f16<64, 8, 32>
<<<blocks_num, block_dim, shmem, main_stream>>> ( <<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query (const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key (const char *) src1_extra->data_device[g_main_device], // Key
@ -10324,7 +10328,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
); );
break; break;
case 80: case 80:
flash_attn_ext_f16<80, 2> flash_attn_ext_f16<80, 8, 32>
<<<blocks_num, block_dim, shmem, main_stream>>> ( <<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query (const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key (const char *) src1_extra->data_device[g_main_device], // Key
@ -10341,7 +10345,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
); );
break; break;
case 128: case 128:
flash_attn_ext_f16<128, 2> flash_attn_ext_f16<128, 8, 32>
<<<blocks_num, block_dim, shmem, main_stream>>> ( <<<blocks_num, block_dim, shmem, main_stream>>> (
(const char *) src0_extra->data_device[g_main_device], // Query (const char *) src0_extra->data_device[g_main_device], // Query
(const char *) src1_extra->data_device[g_main_device], // Key (const char *) src1_extra->data_device[g_main_device], // Key