use half2 instead of half4
This commit is contained in:
parent
6416821499
commit
972c2adc15
1 changed file with 77 additions and 120 deletions
197
ggml-cuda.cu
197
ggml-cuda.cu
|
@ -5992,7 +5992,7 @@ static __global__ void im2col_f32_f16(
|
||||||
|
|
||||||
#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256
|
#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256
|
||||||
|
|
||||||
template<int block_size, int k_seq_len>
|
template<int block_size, int k_seq_len, int k_head_dim>
|
||||||
static __global__ void flash_attn_f32(
|
static __global__ void flash_attn_f32(
|
||||||
const float* __restrict__ q,
|
const float* __restrict__ q,
|
||||||
const float* __restrict__ k,
|
const float* __restrict__ k,
|
||||||
|
@ -6004,9 +6004,9 @@ static __global__ void flash_attn_f32(
|
||||||
const int head_size = head_dim * seq_len;
|
const int head_size = head_dim * seq_len;
|
||||||
const int s = blockIdx.x % seq_len;
|
const int s = blockIdx.x % seq_len;
|
||||||
|
|
||||||
extern __shared__ char shmem__[];
|
extern __shared__ char flash_attn_shmem_f32[];
|
||||||
float* S = (float*)shmem__;
|
float* S = (float*)flash_attn_shmem_f32;
|
||||||
float* warp_data = (float*)(shmem__ + seq_len * sizeof(float));
|
float* warp_data = (float*)(flash_attn_shmem_f32 + seq_len * sizeof(float));
|
||||||
|
|
||||||
// QK^T
|
// QK^T
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@ -6019,11 +6019,11 @@ static __global__ void flash_attn_f32(
|
||||||
const int key_offset = is * head_dim + head * head_size;
|
const int key_offset = is * head_dim + head * head_size;
|
||||||
const int query_offset = s * head_dim + head * head_size;
|
const int query_offset = s * head_dim + head * head_size;
|
||||||
|
|
||||||
S[is] = 0.0f;
|
float tmp = 0.0f;
|
||||||
for(int d = 0; d < head_dim; d++) {
|
for(int d = 0; d < head_dim; d++) {
|
||||||
S[is] += k[key_offset + d] * q[query_offset + d];
|
tmp += k[key_offset + d] * q[query_offset + d];
|
||||||
}
|
}
|
||||||
S[is] *= kq_scale;
|
S[is] = tmp * kq_scale;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
@ -6060,9 +6060,9 @@ static __global__ void flash_attn_f32(
|
||||||
if(is >= seq_len) {
|
if(is >= seq_len) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
float tmp = expf(S[is] - max_val);
|
||||||
S[is] = expf(S[is] - max_val);
|
sum += tmp;
|
||||||
sum += S[is];
|
S[is] = tmp;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
@ -6091,7 +6091,12 @@ static __global__ void flash_attn_f32(
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// softmax(QK^T)V
|
// softmax(QK^T)V
|
||||||
for (int d = threadIdx.x; d < head_dim; d += block_size) {
|
#pragma unroll
|
||||||
|
for (int d0 = threadIdx.x; d0 < k_head_dim; d0 += block_size) {
|
||||||
|
const int d = threadIdx.x + d0;
|
||||||
|
if(d >= head_dim) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
const int dst_index = d + s * head_dim + head * head_size;
|
const int dst_index = d + s * head_dim + head * head_size;
|
||||||
const int value_offset = d * seq_len + head * head_size;
|
const int value_offset = d * seq_len + head * head_size;
|
||||||
|
|
||||||
|
@ -6107,51 +6112,8 @@ static __global__ void flash_attn_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct __align__(8) half4 {
|
|
||||||
half x;
|
|
||||||
half y;
|
|
||||||
half z;
|
|
||||||
half w;
|
|
||||||
};
|
|
||||||
|
|
||||||
__device__ half4 make_half4(half x) {
|
|
||||||
half4 t;
|
|
||||||
t.x = x; t.y = x; t.z = x; t.w = x;
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ half4 __h4fma(half4 a, half b, half4 c) {
|
|
||||||
half4 t;
|
|
||||||
t.x = __hfma(a.x, b, c.x); t.y = __hfma(a.y, b, c.y); t.z = __hfma(a.z, b, c.z); t.w = __hfma(a.w, b, c.w);
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ half4 __h4fma(half4 a, half4 b, half4 c) {
|
|
||||||
half4 t;
|
|
||||||
t.x = __hfma(a.x, b.x, c.x); t.y = __hfma(a.y, b.y, c.y); t.z = __hfma(a.z, b.z, c.z); t.w = __hfma(a.w, b.w, c.w);
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ half4 __h4mul(half4 a, half b) {
|
|
||||||
half4 t;
|
|
||||||
t.x = __hmul(a.x, b); t.y = __hmul(a.y, b); t.z =__hmul(a.z, b); t.w =__hmul(a.w, b);
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ half4 __h4mul(half4 a, half4 b) {
|
|
||||||
half4 t;
|
|
||||||
t.x = __hmul(a.x, b.x); t.y = __hmul(a.y, b.y); t.z =__hmul(a.z, b.z); t.w =__hmul(a.w, b.w);
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
|
|
||||||
__device__ half4 __h4div(half4 a, half b) {
|
|
||||||
half4 t;
|
|
||||||
t.x = __hdiv(a.x, b); t.y = __hdiv(a.y, b); t.z =__hdiv(a.z, b); t.w =__hdiv(a.w, b);
|
|
||||||
return t;
|
|
||||||
}
|
|
||||||
|
|
||||||
// based on metal version
|
// based on metal version
|
||||||
template<int D, int R> // head size, rows per block
|
template<int D, int R> // D head size, R rows per block
|
||||||
static __global__ void flash_attn_ext_f16(
|
static __global__ void flash_attn_ext_f16(
|
||||||
const char* __restrict__ q,
|
const char* __restrict__ q,
|
||||||
const char* __restrict__ k,
|
const char* __restrict__ k,
|
||||||
|
@ -6205,91 +6167,93 @@ static __global__ void flash_attn_ext_f16(
|
||||||
// const int iv2 = iq2 / rv2;
|
// const int iv2 = iq2 / rv2;
|
||||||
// const int iv3 = iq3 / rv3;
|
// const int iv3 = iq3 / rv3;
|
||||||
|
|
||||||
|
const half2 scale_h = __half2half2(__float2half(scale));
|
||||||
|
|
||||||
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
|
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
|
||||||
|
|
||||||
const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
|
const float * mp = mask ? (const float*)(mask + (ir % ne31)*nb31) : nullptr;
|
||||||
|
|
||||||
extern __shared__ char shmem__[];
|
extern __shared__ char data_flash_attn_shmem[];
|
||||||
|
|
||||||
half4* pq4 = (half4*)shmem__;
|
half2* pq2 = (half2*)data_flash_attn_shmem;
|
||||||
half4* ps4 = (half4*)(shmem__ + warp_id * (R * D + 32) + 1*R*D);
|
half2* ps2 = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 1*R*D);
|
||||||
half* ss = (half *)(shmem__ + warp_id * (R * D + 32) + 2*R*D);
|
half2* ss = (half2*)(data_flash_attn_shmem + warp_id * (R*D + 32) + 2*R*D);
|
||||||
|
|
||||||
const int tiih = lane_id % tph; // thread index in head
|
const int tiih = lane_id % tph; // thread index in head
|
||||||
const int hiisg = lane_id / tph; // head index in warp
|
const int hiiw = lane_id / tph; // head index in warp
|
||||||
|
|
||||||
const int D4 = D/4;
|
const int D2 = D / 2; // number of half2 to store head_dim row
|
||||||
|
|
||||||
// load R heads from Q to shared memory
|
// load R heads from Q to shared memory
|
||||||
for (int64_t i = 0; i < D4/tph; ++i) {
|
for (int i = 0; i < D2/tph; ++i) {
|
||||||
if (warp_id == 0) {
|
if (warp_id == 0) {
|
||||||
pq4[hiisg*D4 + tph*i + tiih] = ((half4*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
|
pq2[hiiw*D2 + tph*i + tiih] = ((half2*)(q + (iq1*nb01 + iq2*nb02 + iq3*nb03)))[tph*i + tiih];
|
||||||
}
|
}
|
||||||
|
|
||||||
ps4[hiisg*D4 + tph*i + tiih] = make_half4(0.0);
|
ps2[hiiw*D2 + tph*i + tiih] = make_half2(0.0, 0.0);
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
half S(0.0);
|
half2 S = make_half2(0.0, 0.0);
|
||||||
half M(-INFINITY);
|
half2 M = make_half2(-INFINITY, -INFINITY);
|
||||||
|
|
||||||
for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
|
for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
|
||||||
const half mv = mp ? mp[ic] : 0.0;
|
const half2 mv = make_half2(mp ? mp[ic] : 0.0, 0.0);
|
||||||
if (__hisinf(mv) == -1) { // mv == -INFINITY
|
if (__hisinf(mv.x) == -1) { // mv == -INFINITY
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const half4 * pk4 = (const half4 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
|
half2 * pk2 = (half2 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K
|
half2 * pv2 = (half2 *) ((char *) v + (ic*nb11 + ik2*nb12 + ik3*nb13)); // assumes V same shape of K
|
||||||
|
|
||||||
half4 s4 = make_half4(0.0);
|
half2 s2 = make_half2(0.0, 0.0);
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < D4/tph; ++i) {
|
for (int i = 0; i < D2/tph; ++i) {
|
||||||
s4 = __h4fma(pq4[hiisg*D4 + tph*i + tiih], pk4[tph*i + tiih], s4);
|
s2 = pq2[hiiw*D2 + tph*i + tiih] * pk2[tph*i + tiih] + s2;
|
||||||
}
|
}
|
||||||
|
|
||||||
ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);
|
ss[hiiw*tph + tiih] = __half2half2(s2.x + s2.y);
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
if (tiih == 0) {
|
if (tiih == 0) {
|
||||||
half s = 0.0;
|
half2 s = make_half2(0.0, 0.0);
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < tph; ++i) {
|
for (int i = 0; i < tph; ++i) {
|
||||||
s += ss[hiisg*tph + i];
|
s += ss[hiiw*tph + i];
|
||||||
}
|
}
|
||||||
|
|
||||||
s = __hfma(s, __float2half(scale), mv); // s*scale + mv
|
s = s * scale_h + mv; // s*scale + mv
|
||||||
|
|
||||||
const half m = M;
|
half2 m = M;
|
||||||
|
|
||||||
M = __hmax(M, s);
|
M = __hmax2(M, s);
|
||||||
|
|
||||||
const half ms = hexp(m - M);
|
half2 ms = h2exp(m - M);
|
||||||
const half vs = hexp(s - M);
|
half2 vs = h2exp(s - M);
|
||||||
|
|
||||||
S = __hfma(S, ms, vs);
|
S = S * ms + vs;
|
||||||
|
|
||||||
ss[2*hiisg + 0] = ms;
|
ss[2*hiiw + 0] = ms;
|
||||||
ss[2*hiisg + 1] = vs;
|
ss[2*hiiw + 1] = vs;
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
const half ms = ss[2*hiisg + 0];
|
half2 ms = ss[2*hiiw + 0];
|
||||||
const half vs = ss[2*hiisg + 1];
|
half2 vs = ss[2*hiiw + 1];
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < D4/tph; ++i) {
|
for (int i = 0; i < D2/tph; ++i) {
|
||||||
ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms, __h4mul(pv4[tph*i + tiih], vs));
|
ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms + pv2[tph*i + tiih] * vs;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tiih == 0) {
|
if (tiih == 0) {
|
||||||
ss[2*hiisg + 0] = S;
|
ss[2*hiiw + 0] = S;
|
||||||
ss[2*hiisg + 1] = M;
|
ss[2*hiiw + 1] = M;
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
@ -6297,31 +6261,31 @@ static __global__ void flash_attn_ext_f16(
|
||||||
// reduce the warps
|
// reduce the warps
|
||||||
if (warp_id == 0) {
|
if (warp_id == 0) {
|
||||||
for (int sg = 1; sg < nwraps; ++sg) {
|
for (int sg = 1; sg < nwraps; ++sg) {
|
||||||
const half S0 = ss[ 2*hiisg + 0];
|
half2 S0 = ss[ 2*hiiw + 0];
|
||||||
const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];
|
half2 S1 = ss[sg*(R*D + 32) + 2*hiiw + 0];
|
||||||
|
|
||||||
const half M0 = ss[ 2*hiisg + 1];
|
half2 M0 = ss[ 2*hiiw + 1];
|
||||||
const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
|
half2 M1 = ss[sg*(R*D + 32) + 2*hiiw + 1];
|
||||||
|
|
||||||
M = __hmax(M0, M1);
|
M = __hmax2(M0, M1);
|
||||||
|
|
||||||
const half ms0 = hexp(M0 - M);
|
half2 ms0 = h2exp(M0 - M);
|
||||||
const half ms1 = hexp(M1 - M);
|
half2 ms1 = h2exp(M1 - M);
|
||||||
|
|
||||||
S = __hfma(S0, ms0, __hmul(S1, ms1));
|
S = S0 * ms0 + S1 * ms1;
|
||||||
|
|
||||||
if (tiih == 0) {
|
if (tiih == 0) {
|
||||||
ss[2*hiisg + 0] = S;
|
ss[2*hiiw + 0] = S;
|
||||||
ss[2*hiisg + 1] = M;
|
ss[2*hiiw + 1] = M;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < D4/tph; ++i) {
|
for (int i = 0; i < D2/tph; ++i) {
|
||||||
ps4[hiisg*D4 + tph*i + tiih] = __h4fma(ps4[hiisg*D4 + tph*i + tiih], ms0, __h4mul(ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih], ms1));
|
ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih] * ms1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < D4/tph; ++i) {
|
for (int i = 0; i < D2/tph; ++i) {
|
||||||
ps4[hiisg*D4 + tph*i + tiih] = __h4div(ps4[hiisg*D4 + tph*i + tiih], S);
|
ps2[hiiw*D2 + tph*i + tiih] = __h2div(ps2[hiiw*D2 + tph*i + tiih], S);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6332,17 +6296,10 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int i2 = iq2;
|
const int i2 = iq2;
|
||||||
const int i3 = iq3;
|
const int i3 = iq3;
|
||||||
|
|
||||||
float4 * dst4 = (float4 *) kqv;
|
float2 * dst2 = (float2 *) kqv;
|
||||||
|
|
||||||
if (warp_id == 0) {
|
if (warp_id == 0) {
|
||||||
for (int i = 0; i < D4/tph; ++i) {
|
for (int i = 0; i < D2/tph; ++i) {
|
||||||
float4 dst_ =
|
dst2[(i3*ne2*ne1 + i2 + i1*ne1)*D2 + tph*i + tiih] = __half22float2(ps2[hiiw*D2 + tph*i + tiih]);
|
||||||
dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih];
|
|
||||||
half4 src_ = ps4[hiisg*D4 + tph*i + tiih];
|
|
||||||
dst_.x = __half2float(src_.x);
|
|
||||||
dst_.y = __half2float(src_.y);
|
|
||||||
dst_.z = __half2float(src_.z);
|
|
||||||
dst_.w = __half2float(src_.w);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7741,7 +7698,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst,
|
||||||
static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
|
static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
|
||||||
int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
|
int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
|
||||||
int num_blocks = num_heads * seq_len;
|
int num_blocks = num_heads * seq_len;
|
||||||
flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, 1024><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
|
flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, 1024, 64><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
|
||||||
q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
|
q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10342,11 +10299,11 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
||||||
const int nwarps = 32;
|
const int nwarps = 32;
|
||||||
const int nhpw = 2; // heads per warp
|
const int nhpw = 2; // heads per warp
|
||||||
|
|
||||||
dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1)/(nhpw), Q->ne[3]);
|
dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1) / nhpw, Q->ne[3]);
|
||||||
dim3 block_dim(32, nwarps, 1);
|
dim3 block_dim(32 * nwarps, 1, 1);
|
||||||
|
|
||||||
int shmem = (nhpw*Q->ne[0] + nwarps*(nhpw*Q->ne[0] + 32))*(sizeof(float)/2);
|
|
||||||
|
|
||||||
|
int shmem = (nhpw*Q->ne[0]*2 + nwarps*(nhpw*Q->ne[0] + 32)) * (sizeof(float)/2);
|
||||||
|
printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]);
|
||||||
switch (Q->ne[0])
|
switch (Q->ne[0])
|
||||||
{
|
{
|
||||||
case 64:
|
case 64:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue