update implementation
parent 78da3387a8
commit 6e7cb0eeaf

1 changed file with 52 additions and 48 deletions: ggml-cuda.cu
@@ -6113,7 +6113,7 @@ static __global__ void flash_attn_f32(
 }
 
 // based on metal version
-template<int D, int R> // D head size, R rows per block
+template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
 static __global__ void flash_attn_ext_f16(
         const char* __restrict__ q,
         const char* __restrict__ k,
@@ -6141,62 +6141,64 @@ static __global__ void flash_attn_ext_f16(
         int ne1,
         int ne2,
         int ne3) {
-    int warp_id = threadIdx.x / WARP_SIZE;
-    int lane_id = threadIdx.x % WARP_SIZE;
+    const int warp_id = threadIdx.y;
+    const int lane_id = threadIdx.x;
 
-    const int nwraps = blockDim.y; // number of warps
-    const int tph = WARP_SIZE / R; // threads per head
+    const int n_warps = blockDim.y; // number of warps
 
     const int iq3 = blockIdx.z;
-    const int iq2 = blockIdx.y * R + lane_id / tph;
-    const int iq1 = blockIdx.x;
+    const int iq2 = blockIdx.y;
+    const int iq1 = blockIdx.x * Q;
 
-    if(iq2 >= ne02) {
-        return;
-    }
+    const int D2 = D/2;
+    const int N4 = WARP_SIZE;
+    const int L2 = (D2 + N4 - 1)/N4;
+    const int D8 = D/8;
 
-    // broadcast
-    const int rk2 = ne02 / ne12;
-    const int rk3 = ne03 / ne13;
-    // assume the same K and V shape
-    // const int rv2 = ne02 / ne12;
-    // const int rv3 = ne03 / ne13;
-
-    // kv indices
-    const int ik2 = iq2 / rk2;
-    const int ik3 = iq3 / rk3;
-    // const int iv2 = iq2 / rv2;
-    // const int iv3 = iq3 / rv3;
+    const int T  = D + n_warps*(D + 1*C); // shared memory size per query in half
+    const int T2 = T/2;                   // shared memory size per query in half2
 
     const half2 scale_h = __half2half2(__float2half(scale));
-
-    const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
-    const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
 
     extern __shared__ char data_flash_attn_shmem[];
 
-    half2* pq2 = (half2*)data_flash_attn_shmem;
-    half2* ps2 = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 1*R*D);
-    half2* ss  = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 2*R*D);
+    half  * pq  = (half  *) (data_flash_attn_shmem +                    0*D);
+    half2 * pq2 = (half2 *) (data_flash_attn_shmem +                    0*D);
+    half  * ps  = (half  *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
+    half2 * ps2 = (half2 *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 1*D);
+    half  * ss  = (half  *) (data_flash_attn_shmem + warp_id*(D + 1*C) + 2*D);
 
-    const int tiih = lane_id % tph; // thread index in head
-    const int hiiw = lane_id / tph; // head index in warp
-    const int D2 = D / 2; // number of half2 to store head_dim row
-    // load R heads from Q to shared memory
-    for (int i = 0; i < D2/tph; ++i) {
-        if (warp_id == 0) {
-            pq2[hiiw*D2 + tph*i + tiih] = ((half2*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
+    for (int i = 0; i < L2; ++i) {
+        // load heads from Q to shared memory
+        for (int j = warp_id; j < Q; j += n_warps) {
+            if (iq1 + j < ne01) {
+                pq2[j*T2 + N4*i + lane_id] = ((half2*) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + lane_id];
+            } else {
+                pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0);
+            }
         }
 
-        ps2[hiiw*D2 + tph*i + tiih] = make_half2(0.0, 0.0);
+        // zero out shared memory
+        for (int j = 0; j < Q; ++j) {
+            ps2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0);
+        }
     }
 
+    if (lane_id < C) {
+        for (int j = 0; j < Q; ++j) {
+            ss[j*T + 0 + lane_id] = 0.0;
+        }
+    }
+
     __syncthreads();
 
-    half2 S = make_half2(0.0, 0.0);
+    const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
 
+    half S[8] = { 0.0 };
+#if 0
     half2 M = make_half2(-INFINITY, -INFINITY);
 
+    const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
+
     for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
         const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0);
         if (__hisinf(mv.x) == -1) { // mv == -INFINITY
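For orientation (editor's sketch, not part of the commit): the bookkeeping above packs, per query, D halves for the Q row followed by one (D + C)-sized region per warp holding that warp's partial output (ps) and attention scores (ss), for T = D + n_warps*(D + 1*C) halves per query in total. A small host-side sketch of those per-query offsets, with example values assumed for D, C and the warp count:

```cuda
// Sketch only: per-query layout implied by T = D + n_warps*(D + 1*C) above;
// D, C and n_warps are assumed example values, not taken from a real launch.
#include <cstdio>

int main() {
    const int D = 64, C = 32, n_warps = 4;   // example values
    const int T = D + n_warps*(D + 1*C);     // halves per query, as in the kernel

    printf("pq: offset 0, %d halves\n", D);
    for (int w = 0; w < n_warps; ++w) {
        printf("warp %d: ps at offset %d (%d halves), ss at offset %d (%d halves)\n",
               w, w*(D + 1*C) + 1*D, D, w*(D + 1*C) + 2*D, C);
    }
    printf("total: T = %d halves per query (T2 = %d half2)\n", T, T/2);
    return 0;
}
```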
@@ -6302,6 +6304,7 @@ static __global__ void flash_attn_ext_f16(
             dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]);
         }
     }
+#endif
 }
 
 
@@ -10296,18 +10299,19 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
-    const int nwarps = 32;
-    const int nhpw = 2; // heads per warp
+    const int nwarps = Q->ne[1] < 4 ? 4 : 2;
+    const int nqpb = 2;  // queries per block
+    const int ncpw = 32; // cache values per warp (does not work for other values)
 
-    dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1) / nhpw, Q->ne[3]);
-    dim3 block_dim(32 * nwarps, 1, 1);
+    dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
+    dim3 block_dim(32, nwarps, 1);
 
-    int shmem = (nhpw*Q->ne[0]*2 + nwarps*(nhpw*Q->ne[0] + 32)) * (sizeof(float)/2);
+    int shmem = nqpb*(Q->ne[0] + nwarps*(Q->ne[0] + 1*ncpw))*(sizeof(float)/2);
     printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]);
     switch (Q->ne[0])
     {
         case 64:
-            flash_attn_ext_f16<64, 2>
+            flash_attn_ext_f16<64, 8, 32>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
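As a cross-check (editor's sketch, not part of the commit), the host-side shared-memory size above can be reproduced in isolation; head_dim stands in for Q->ne[0] and its value here is only an example:

```cuda
// Sketch only: reproduces the shmem computation from the launch code above.
// nwarps/nqpb/ncpw mirror the host variables; sizeof(float)/2 == sizeof(half) == 2 bytes.
#include <cstdio>

int main() {
    const int head_dim = 64;  // stands in for Q->ne[0]; example value
    const int nwarps   = 4;   // host picks Q->ne[1] < 4 ? 4 : 2
    const int nqpb     = 2;   // queries per block
    const int ncpw     = 32;  // cache values per warp

    const int shmem = nqpb*(head_dim + nwarps*(head_dim + 1*ncpw))*(sizeof(float)/2);
    printf("shared memory: %d bytes\n", shmem);
    return 0;
}
```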
@@ -10324,7 +10328,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 80:
-            flash_attn_ext_f16<80, 2>
+            flash_attn_ext_f16<80, 8, 32>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10341,7 +10345,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 128:
-            flash_attn_ext_f16<128, 2>
+            flash_attn_ext_f16<128, 8, 32>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                     (const char *) src0_extra->data_device[g_main_device], // Query
                     (const char *) src1_extra->data_device[g_main_device], // Key