use half2 instead of half4

FSSRepo 2024-01-24 16:41:57 -05:00
parent 6416821499
commit 972c2adc15


@@ -5992,7 +5992,7 @@ static __global__ void im2col_f32_f16(

 #define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256

-template<int block_size, int k_seq_len>
+template<int block_size, int k_seq_len, int k_head_dim>
 static __global__ void flash_attn_f32(
         const float* __restrict__ q,
         const float* __restrict__ k,
@@ -6004,9 +6004,9 @@ static __global__ void flash_attn_f32(
     const int head_size = head_dim * seq_len;
     const int s = blockIdx.x % seq_len;

-    extern __shared__ char shmem__[];
-    float* S = (float*)shmem__;
-    float* warp_data = (float*)(shmem__ + seq_len * sizeof(float));
+    extern __shared__ char flash_attn_shmem_f32[];
+    float* S = (float*)flash_attn_shmem_f32;
+    float* warp_data = (float*)(flash_attn_shmem_f32 + seq_len * sizeof(float));

     // QK^T
 #pragma unroll
@@ -6019,11 +6019,11 @@ static __global__ void flash_attn_f32(
         const int key_offset = is * head_dim + head * head_size;
         const int query_offset = s * head_dim + head * head_size;

-        S[is] = 0.0f;
+        float tmp = 0.0f;
         for(int d = 0; d < head_dim; d++) {
-            S[is] += k[key_offset + d] * q[query_offset + d];
+            tmp += k[key_offset + d] * q[query_offset + d];
         }
-        S[is] *= kq_scale;
+        S[is] = tmp * kq_scale;
     }
     __syncthreads();
@@ -6060,9 +6060,9 @@ static __global__ void flash_attn_f32(
         if(is >= seq_len) {
             break;
         }
-        S[is] = expf(S[is] - max_val);
-        sum += S[is];
+        float tmp = expf(S[is] - max_val);
+        sum += tmp;
+        S[is] = tmp;
     }
     __syncthreads();
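The two flash_attn_f32 hunks above apply the same micro-optimization: accumulate in a register and touch shared memory once, instead of a read-modify-write of S[is] on every iteration. A minimal before/after sketch (plain CUDA, same names as the kernel):

```cuda
// Before: one shared-memory load + store per inner iteration.
S[is] = 0.0f;
for (int d = 0; d < head_dim; d++) {
    S[is] += k[key_offset + d] * q[query_offset + d];
}
S[is] *= kq_scale;

// After: accumulate in a register, hit shared memory exactly once.
float tmp = 0.0f;
for (int d = 0; d < head_dim; d++) {
    tmp += k[key_offset + d] * q[query_offset + d];
}
S[is] = tmp * kq_scale;
```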
@@ -6091,7 +6091,12 @@ static __global__ void flash_attn_f32(
     __syncthreads();

     // softmax(QK^T)V
-    for (int d = threadIdx.x; d < head_dim; d += block_size) {
+#pragma unroll
+    for (int d0 = threadIdx.x; d0 < k_head_dim; d0 += block_size) {
+        const int d = threadIdx.x + d0;
+        if(d >= head_dim) {
+            break;
+        }
         const int dst_index = d + s * head_dim + head * head_size;
         const int value_offset = d * seq_len + head * head_size;
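The point of the new k_head_dim template parameter is that #pragma unroll cannot fully unroll a loop bounded by the runtime head_dim, but it can when the bound is a compile-time constant; a cheap guard then handles heads smaller than the cap. A sketch of the guarded-unroll pattern (illustrative function, mirroring the k_seq_len loops earlier in the kernel):

```cuda
// Compile-time bound k_head_dim lets nvcc fully unroll; the runtime
// head_dim <= k_head_dim is enforced by the early break.
template<int block_size, int k_head_dim>
__device__ void process_columns(int head_dim) {
#pragma unroll
    for (int d0 = 0; d0 < k_head_dim; d0 += block_size) {
        const int d = threadIdx.x + d0;
        if (d >= head_dim) {
            break;
        }
        // ... work on output column d ...
    }
}
```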
@@ -6107,51 +6112,8 @@ static __global__ void flash_attn_f32(
     }
 }

-struct __align__(8) half4 {
-    half x;
-    half y;
-    half z;
-    half w;
-};
-
-__device__ half4 make_half4(half x) {
-    half4 t;
-    t.x = x; t.y = x; t.z = x; t.w = x;
-    return t;
-}
-
-__device__ half4 __h4fma(half4 a, half b, half4 c) {
-    half4 t;
-    t.x = __hfma(a.x, b, c.x); t.y = __hfma(a.y, b, c.y); t.z = __hfma(a.z, b, c.z); t.w = __hfma(a.w, b, c.w);
-    return t;
-}
-
-__device__ half4 __h4fma(half4 a, half4 b, half4 c) {
-    half4 t;
-    t.x = __hfma(a.x, b.x, c.x); t.y = __hfma(a.y, b.y, c.y); t.z = __hfma(a.z, b.z, c.z); t.w = __hfma(a.w, b.w, c.w);
-    return t;
-}
-
-__device__ half4 __h4mul(half4 a, half b) {
-    half4 t;
-    t.x = __hmul(a.x, b); t.y = __hmul(a.y, b); t.z = __hmul(a.z, b); t.w = __hmul(a.w, b);
-    return t;
-}
-
-__device__ half4 __h4mul(half4 a, half4 b) {
-    half4 t;
-    t.x = __hmul(a.x, b.x); t.y = __hmul(a.y, b.y); t.z = __hmul(a.z, b.z); t.w = __hmul(a.w, b.w);
-    return t;
-}
-
-__device__ half4 __h4div(half4 a, half b) {
-    half4 t;
-    t.x = __hdiv(a.x, b); t.y = __hdiv(a.y, b); t.z = __hdiv(a.z, b); t.w = __hdiv(a.w, b);
-    return t;
-}
-
 // based on metal version
-template<int D, int R> // head size, rows per block
+template<int D, int R> // D head size, R rows per block
 static __global__ void flash_attn_ext_f16(
         const char* __restrict__ q,
         const char* __restrict__ k,
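The hand-rolled half4 struct and its __h4* wrappers can be deleted because <cuda_fp16.h> already ships a native two-lane half2 type with hardware-backed intrinsics and, on sm_53+, operator overloads. A quick reference sketch of the replacements used in the rewritten kernel (illustrative, not part of the commit; __hmax2 needs a recent CUDA toolkit):

```cuda
#include <cuda_fp16.h>

// Each line stands alone; r is just a scratch result.
__device__ half2 half2_cheatsheet(half2 a, half2 b, half2 c) {
    half2 r;
    r = __hfma2(a, b, c);  // lane-wise fused multiply-add (was __h4fma)
    r = a * b + c;         // the same via operator overloads (sm_53+)
    r = __hmul2(a, b);     // lane-wise multiply (was __h4mul)
    r = __h2div(a, b);     // lane-wise divide (was __h4div)
    r = __hmax2(a, b);     // lane-wise max, for the running maximum
    r = h2exp(a);          // lane-wise exp, for the softmax rescaling
    return r;
}
```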
@@ -6205,91 +6167,93 @@ static __global__ void flash_attn_ext_f16(
     // const int iv2 = iq2 / rv2;
     // const int iv3 = iq3 / rv3;

+    const half2 scale_h = __half2half2(__float2half(scale));
+
     const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;

     const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;

-    extern __shared__ char shmem__[];
+    extern __shared__ char data_flash_attn_shmem[];

-    half4* pq4 = (half4*)shmem__;
-    half4* ps4 = (half4*)(shmem__ + warp_id * (R * D + 32) + 1*R*D);
-    half*  ss  = (half *)(shmem__ + warp_id * (R * D + 32) + 2*R*D);
+    half2* pq2 = (half2*)data_flash_attn_shmem;
+    half2* ps2 = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 1*R*D);
+    half2* ss  = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 2*R*D);

     const int tiih  = lane_id % tph; // thread index in head
-    const int hiisg = lane_id / tph; // head index in warp
+    const int hiiw  = lane_id / tph; // head index in warp

-    const int D4 = D/4;
+    const int D2 = D / 2; // number of half2 to store head_dim row

     // load R heads from Q to shared memory
-    for (int64_t i = 0; i < D4/tph; ++i) {
+    for (int i = 0; i < D2/tph; ++i) {
         if (warp_id == 0) {
-            pq4[hiisg*D4 + tph*i + tiih] = ((half4*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
+            pq2[hiiw*D2 + tph*i + tiih] = ((half2*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
         }
-        ps4[hiisg*D4 + tph*i + tiih] = make_half4(0.0);
+        ps2[hiiw*D2 + tph*i + tiih] = make_half2(0.0, 0.0);
     }

     __syncthreads();

-    half S(0.0);
-    half M(-INFINITY);
+    half2 S = make_half2(0.0, 0.0);
+    half2 M = make_half2(-INFINITY, -INFINITY);

     for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
-        const half mv = mp ? mp[ic] : 0.0;
-        if (__hisinf(mv) == -1) { // mv == -INFINITY
+        const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0);
+        if (__hisinf(mv.x) == -1) { // mv == -INFINITY
             continue;
         }

-        const half4 * pk4 = (const half4 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
-        const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K
+        half2 * pk2 = (half2 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
+        half2 * pv2 = (half2 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K

-        half4 s4 = make_half4(0.0);
+        half2 s2 = make_half2(0.0, 0.0);

 #pragma unroll
-        for (int i = 0; i < D4/tph; ++i) {
-            s4 = __h4fma(pq4[hiisg*D4 + tph*i + tiih], pk4[tph*i + tiih], s4);
+        for (int i = 0; i < D2/tph; ++i) {
+            s2 = pq2[hiiw*D2 + tph*i + tiih] * pk2[tph*i + tiih] + s2;
         }

-        ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);
+        ss[hiiw*tph + tiih] = __half2half2(s2.x + s2.y);

         __syncthreads();

         if (tiih == 0) {
-            half s = 0.0;
+            half2 s = make_half2(0.0, 0.0);

 #pragma unroll
             for (int i = 0; i < tph; ++i) {
-                s += ss[hiisg*tph + i];
+                s += ss[hiiw*tph + i];
             }

-            s = __hfma(s, __float2half(scale), mv); // s*scale + mv
+            s = s * scale_h + mv; // s*scale + mv

-            const half m = M;
+            half2 m = M;

-            M = __hmax(M, s);
+            M = __hmax2(M, s);

-            const half ms = hexp(m - M);
-            const half vs = hexp(s - M);
+            half2 ms = h2exp(m - M);
+            half2 vs = h2exp(s - M);

-            S = __hfma(S, ms, vs);
+            S = S * ms + vs;

-            ss[2*hiisg + 0] = ms;
-            ss[2*hiisg + 1] = vs;
+            ss[2*hiiw + 0] = ms;
+            ss[2*hiiw + 1] = vs;
         }

         __syncthreads();

-        const half ms = ss[2*hiisg + 0];
-        const half vs = ss[2*hiisg + 1];
+        half2 ms = ss[2*hiiw + 0];
+        half2 vs = ss[2*hiiw + 1];

 #pragma unroll
-        for (int i = 0; i < D4/tph; ++i) {
-            ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms, __h4mul(pv4[tph*i + tiih], vs));
+        for (int i = 0; i < D2/tph; ++i) {
+            ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms + pv2[tph*i + tiih] * vs;
         }
     }

     if (tiih == 0) {
-        ss[2*hiisg + 0] = S;
-        ss[2*hiisg + 1] = M;
+        ss[2*hiiw + 0] = S;
+        ss[2*hiiw + 1] = M;
     }

     __syncthreads();
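For reference, the half2 arithmetic above implements the standard online-softmax recurrence, processing two heads at a time (one per half2 lane). In scalar float form (illustrative names, not from the diff):

```cuda
// One streaming-softmax step for a new score s and value element v.
// M is the running max, S the running normalizer, acc the running
// numerator; acc/S at the end yields softmax(scores) . V.
__device__ void online_softmax_step(float s, float v,
                                    float &M, float &S, float &acc) {
    const float m  = M;
    M = fmaxf(M, s);               // new running maximum
    const float ms = expf(m - M);  // rescale factor for old terms
    const float vs = expf(s - M);  // weight of the new term
    S   = S * ms + vs;
    acc = acc * ms + v * vs;
}
```

The kernel keeps acc per output element in ps2 and defers the final division by S to the warp-reduction step (the __h2div below).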
@@ -6297,31 +6261,31 @@ static __global__ void flash_attn_ext_f16(
     // reduce the warps
     if (warp_id == 0) {
         for (int sg = 1; sg < nwraps; ++sg) {
-            const half S0 = ss[                2*hiisg + 0];
-            const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];
-            const half M0 = ss[                2*hiisg + 1];
-            const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
+            half2 S0 = ss[                2*hiiw + 0];
+            half2 S1 = ss[sg*(R*D + 32) + 2*hiiw + 0];
+            half2 M0 = ss[                2*hiiw + 1];
+            half2 M1 = ss[sg*(R*D + 32) + 2*hiiw + 1];

-            M = __hmax(M0, M1);
+            M = __hmax2(M0, M1);

-            const half ms0 = hexp(M0 - M);
-            const half ms1 = hexp(M1 - M);
+            half2 ms0 = h2exp(M0 - M);
+            half2 ms1 = h2exp(M1 - M);

-            S = __hfma(S0, ms0, __hmul(S1, ms1));
+            S = S0 * ms0 + S1 * ms1;

             if (tiih == 0) {
-                ss[2*hiisg + 0] = S;
-                ss[2*hiisg + 1] = M;
+                ss[2*hiiw + 0] = S;
+                ss[2*hiiw + 1] = M;
             }

-            for (int i = 0; i < D4/tph; ++i) {
-                ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms0, __h4mul(ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih], ms1));
+            for (int i = 0; i < D2/tph; ++i) {
+                ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih] * ms1;
             }
         }

-        for (int i = 0; i < D4/tph; ++i) {
-            ps4[hiisg*D4 + tph*i + tiih] = __h4div(ps4[hiisg*D4 + tph*i + tiih], S);
+        for (int i = 0; i < D2/tph; ++i) {
+            ps2[hiiw*D2 + tph*i + tiih] = __h2div(ps2[hiiw*D2 + tph*i + tiih], S);
         }
     }
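The warp reduction merges the per-warp partial states pairwise; combining (S0, M0) with (S1, M1) is the same rescaling trick applied once more (scalar sketch, illustrative):

```cuda
// Merge two partial softmax states into one consistent state.
__device__ void merge_softmax_states(float S0, float M0,
                                     float S1, float M1,
                                     float &S, float &M) {
    M = fmaxf(M0, M1);
    const float ms0 = expf(M0 - M);
    const float ms1 = expf(M1 - M);
    S = S0 * ms0 + S1 * ms1;
    // the accumulators are combined the same way: acc0*ms0 + acc1*ms1
}
```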
@@ -6332,17 +6296,10 @@ static __global__ void flash_attn_ext_f16(
         const int i2 = iq2;
         const int i3 = iq3;

-        float4 * dst4 = (float4 *) kqv;
+        float2 * dst2 = (float2 *) kqv;

         if (warp_id == 0) {
-            for (int i = 0; i < D4/tph; ++i) {
-                float4 dst_ =
-                    dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih];
-                half4 src_ = ps4[hiisg*D4 + tph*i + tiih];
-                dst_.x = __half2float(src_.x);
-                dst_.y = __half2float(src_.y);
-                dst_.z = __half2float(src_.z);
-                dst_.w = __half2float(src_.w);
+            for (int i = 0; i < D2/tph; ++i) {
+                dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]);
             }
         }
 }
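Side note: __half22float2 converts both half lanes to a float2 in one call, which is what lets the four scalar __half2float conversions collapse into a single vectorized store. A minimal sketch:

```cuda
#include <cuda_fp16.h>

// Convert one half2 to float2 with a single intrinsic.
__device__ float2 to_float2(half2 h) {
    return __half22float2(h);  // .x and .y converted together
}
```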
@@ -7741,7 +7698,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst,
 static void flash_attn_f32_cuda(const float* q, const float* k, const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
     int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
     int num_blocks = num_heads * seq_len;
-    flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, 1024><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
+    flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, 1024, 64><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
         q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
 }
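At this call site the kernel now carries both compile-time caps: k_seq_len = 1024 and the new k_head_dim = 64. The runtime seq_len and d_head may be smaller; the in-kernel guards handle that. Annotated form of the launch (comments added for illustration only):

```cuda
flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, // block_size
               1024,                            // k_seq_len  (compile-time cap)
               64>                              // k_head_dim (compile-time cap)
    <<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
        q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
```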
@@ -10342,11 +10299,11 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     const int nwarps = 32;
     const int nhpw = 2; // heads per warp

-    dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1)/(nhpw), Q->ne[3]);
-    dim3 block_dim(32, nwarps, 1);
+    dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1) / nhpw, Q->ne[3]);
+    dim3 block_dim(32 * nwarps, 1, 1);

-    int shmem = (nhpw*Q->ne[0] + nwarps*(nhpw*Q->ne[0] + 32))*(sizeof(float)/2);
+    int shmem = (nhpw*Q->ne[0]*2 + nwarps*(nhpw*Q->ne[0] + 32)) * (sizeof(float)/2);
+    printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]);

     switch (Q->ne[0])
     {
         case 64:
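As a sanity check on the new shmem expression for the case 64 path, with nhpw = 2 and nwarps = 32 (sizeof(float)/2 equals sizeof(half) = 2 bytes):

shmem = (2*64*2 + 32*(2*64 + 32)) * 2 = (256 + 5120) * 2 = 10752 bytes per block,

well under the default 48 KB per-block shared-memory limit.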