cuda: port metal version flash_attn_ext
This commit is contained in:
parent
a689b02ad3
commit
6374bc5779
1 changed files with 304 additions and 1 deletions
305
ggml-cuda.cu
305
ggml-cuda.cu
|
@ -937,6 +937,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
|
||||||
if (lane_id == 0) {
|
if (lane_id == 0) {
|
||||||
s_sum[warp_id] = tmp;
|
s_sum[warp_id] = tmp;
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
tmp = s_sum[lane_id];
|
tmp = s_sum[lane_id];
|
||||||
tmp = warp_reduce_sum(tmp);
|
tmp = warp_reduce_sum(tmp);
|
||||||
|
@ -6106,6 +6107,211 @@ static __global__ void flash_attn_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct __align__(8) half4 {
|
||||||
|
half x;
|
||||||
|
half y;
|
||||||
|
half z;
|
||||||
|
half w;
|
||||||
|
};
|
||||||
|
|
||||||
|
// based on metal version
|
||||||
|
template<int D, int R> // head size, rows per block
|
||||||
|
static __global__ void flash_attn_ext_f16(
|
||||||
|
const char* __restrict__ q,
|
||||||
|
const char* __restrict__ k,
|
||||||
|
const char* __restrict__ v,
|
||||||
|
const char* __restrict__ mask,
|
||||||
|
float* __restrict__ kqv,
|
||||||
|
float scale,
|
||||||
|
int ne00,
|
||||||
|
int ne01,
|
||||||
|
int ne02,
|
||||||
|
int ne03,
|
||||||
|
int ne10,
|
||||||
|
int ne11,
|
||||||
|
int ne12,
|
||||||
|
int ne13,
|
||||||
|
int ne31,
|
||||||
|
int nb31,
|
||||||
|
int nb01,
|
||||||
|
int nb02,
|
||||||
|
int nb03,
|
||||||
|
int nb11,
|
||||||
|
int nb12,
|
||||||
|
int nb13,
|
||||||
|
int ne0,
|
||||||
|
int ne1,
|
||||||
|
int ne2,
|
||||||
|
int ne3) {
|
||||||
|
int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
|
int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
|
||||||
|
const int nwraps = blockDim.y; // number of warps
|
||||||
|
const int tph = WARP_SIZE / R; // threads per head
|
||||||
|
const int iq3 = blockIdx.z;
|
||||||
|
const int iq2 = blockIdx.y * R + lane_id / tph;
|
||||||
|
const int iq1 = blockIdx.x;
|
||||||
|
|
||||||
|
if(iq2 >= ne02) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// broadcast
|
||||||
|
const int rk2 = ne02 / ne12;
|
||||||
|
const int rk3 = ne03 / ne13;
|
||||||
|
// assume the same K and V shape
|
||||||
|
// const int rv2 = ne02 / ne12;
|
||||||
|
// const int rv3 = ne03 / ne13;
|
||||||
|
|
||||||
|
// kv indices
|
||||||
|
const int ik2 = iq2 / rk2;
|
||||||
|
const int ik3 = iq3 / rk3;
|
||||||
|
const int iv2 = iq2 / rv2;
|
||||||
|
const int iv3 = iq3 / rv3;
|
||||||
|
|
||||||
|
const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
|
||||||
|
|
||||||
|
const float * mp = mask ? mask + (ir % ne31)*nb31 : nullptr;
|
||||||
|
|
||||||
|
extern __shared__ char shmem__[];
|
||||||
|
|
||||||
|
half4* pq4 = (half4*)shmem__;
|
||||||
|
half4* ps4 = (half4*)(shmem__ + warp_id * (R * D + 32) + 1*R*D);
|
||||||
|
half* ss = (half *)(shmem__ + warp_id * (R * D + 32) + 2*R*D);
|
||||||
|
|
||||||
|
const int tiih = lane_id % tph; // thread index in head
|
||||||
|
const int hiisg = lane_id / tph; // head index in warp
|
||||||
|
|
||||||
|
const int D4 = D/4;
|
||||||
|
|
||||||
|
// load R heads from Q to shared memory
|
||||||
|
for (int64_t i = 0; i < D4/tph; ++i) {
|
||||||
|
if (warp_id == 0) {
|
||||||
|
pq4[hiisg*D4 + tph*i + tiih] = (const half4*)((const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03))[tph*i + tiih];
|
||||||
|
}
|
||||||
|
|
||||||
|
ps4[hiisg*D4 + tph*i + tiih] = 0.0h;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
half S = 0.0h;
|
||||||
|
half M = -INFINITY;
|
||||||
|
|
||||||
|
for (int64_t ic = warp_id; ic < ne11; ic += nwraps) {
|
||||||
|
const half mv = mp ? mp[ic] : 0.0h;
|
||||||
|
if (mv == -INFINITY) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const half4 * pk4 = (const half4 *) ((char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13));
|
||||||
|
const half4 * pv4 = (const half4 *) ((char *) v + (ic*nb11 + iv2*nb12 + iv3*nb13)); // assumes V same shape of K
|
||||||
|
|
||||||
|
half4 s4 = 0.0h;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < D4/tph; ++i) {
|
||||||
|
s4 += pq4[hiisg*D4 + tph*i + tiih] * pk4[tph*i + tiih];
|
||||||
|
}
|
||||||
|
|
||||||
|
ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
if (tiih == 0) {
|
||||||
|
half s = 0.0h;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < tph; ++i) {
|
||||||
|
s += ss[hiisg*tph + i];
|
||||||
|
}
|
||||||
|
|
||||||
|
s = s*scale + mv;
|
||||||
|
|
||||||
|
const half m = M;
|
||||||
|
|
||||||
|
M = max(M, s);
|
||||||
|
|
||||||
|
const half ms = exp(m - M);
|
||||||
|
const half vs = exp(s - M);
|
||||||
|
|
||||||
|
S = S*ms + vs;
|
||||||
|
|
||||||
|
ss[2*hiisg + 0] = ms;
|
||||||
|
ss[2*hiisg + 1] = vs;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const half ms = ss[2*hiisg + 0];
|
||||||
|
const half vs = ss[2*hiisg + 1];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < D4/tph; ++i) {
|
||||||
|
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms + pv4[tph*i + tiih]*vs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tiih == 0) {
|
||||||
|
ss[2*hiisg + 0] = S;
|
||||||
|
ss[2*hiisg + 1] = M;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// reduce the warps
|
||||||
|
if (warp_id == 0) {
|
||||||
|
for (int sg = 1; sg < nwraps; ++sg) {
|
||||||
|
const half S0 = ss[ 2*hiisg + 0];
|
||||||
|
const half S1 = ss[sg*(R*D + 32) + 2*hiisg + 0];
|
||||||
|
|
||||||
|
const half M0 = ss[ 2*hiisg + 1];
|
||||||
|
const half M1 = ss[sg*(R*D + 32) + 2*hiisg + 1];
|
||||||
|
|
||||||
|
M = max(M0, M1);
|
||||||
|
|
||||||
|
const half ms0 = exp(M0 - M);
|
||||||
|
const half ms1 = exp(M1 - M);
|
||||||
|
|
||||||
|
S = S0*ms0 + S1*ms1;
|
||||||
|
|
||||||
|
if (tiih == 0) {
|
||||||
|
ss[2*hiisg + 0] = S;
|
||||||
|
ss[2*hiisg + 1] = M;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < D4/tph; ++i) {
|
||||||
|
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]*ms0 + ps4[sg*(R*D + 32)/4 + hiisg*D4 + tph*i + tiih]*ms1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < D4/tph; ++i) {
|
||||||
|
ps4[hiisg*D4 + tph*i + tiih] = ps4[hiisg*D4 + tph*i + tiih]/S;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// dst indices
|
||||||
|
const int i1 = iq1;
|
||||||
|
const int i2 = iq2;
|
||||||
|
const int i3 = iq3;
|
||||||
|
|
||||||
|
float4 * dst4 = (float4 *) kqv;
|
||||||
|
|
||||||
|
if (warp_id == 0) {
|
||||||
|
for (int i = 0; i < D4/tph; ++i) {
|
||||||
|
float4 dst_ =
|
||||||
|
dst4[(i3*ne2*ne1 + i2 + i1*ne1)*D4 + tph*i + tiih];
|
||||||
|
half4 src_ = ps4[hiisg*D4 + tph*i + tiih];
|
||||||
|
dst_.x = __half2float(src_.x);
|
||||||
|
dst_.y = __half2float(src_.y);
|
||||||
|
dst_.z = __half2float(src_.z);
|
||||||
|
dst_.w = __half2float(src_.w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template<int qk, int qr, dequantize_kernel_t dq>
|
template<int qk, int qr, dequantize_kernel_t dq>
|
||||||
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||||
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
|
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
|
||||||
|
@ -10071,6 +10277,98 @@ inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, c
|
||||||
KQ_scale, d_head, sequence_length, num_heads, main_stream);
|
KQ_scale, d_head, sequence_length, num_heads, main_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
|
||||||
|
GGML_ASSERT(Q->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(K->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(V->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(mask->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
|
||||||
|
GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
|
||||||
|
GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
|
||||||
|
GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
|
||||||
|
GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
|
||||||
|
|
||||||
|
ggml_cuda_set_device(g_main_device);
|
||||||
|
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
|
||||||
|
|
||||||
|
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
|
||||||
|
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
|
||||||
|
ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
|
||||||
|
ggml_tensor_extra_gpu * src3_extra = (ggml_tensor_extra_gpu *) mask->extra;
|
||||||
|
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra;
|
||||||
|
|
||||||
|
float scale;
|
||||||
|
memcpy(&scale, KQV->op_params, sizeof(float));
|
||||||
|
|
||||||
|
const int nwarps = 32;
|
||||||
|
const int nhpw = 2; // heads per warp
|
||||||
|
|
||||||
|
dim3 blocks_num(Q->ne[1], (Q->ne[2] + nhpw - 1)/(nhpw), Q->ne[3]);
|
||||||
|
dim3 block_dim(32, nwarps, 1);
|
||||||
|
|
||||||
|
int shmem = (nhpw*Q->ne[0] + nwarps*(nhpw*Q->ne[0] + 32))*(sizeof(float)/2);
|
||||||
|
|
||||||
|
switch (Q->ne[0])
|
||||||
|
{
|
||||||
|
case 64:
|
||||||
|
flash_attn_ext_f16<64, 2>
|
||||||
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
|
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||||
|
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||||
|
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||||
|
(const char *) src3_extra->data_device[g_main_device], // Mask
|
||||||
|
(float *) dst_extra->data_device[g_main_device], // dst
|
||||||
|
scale,
|
||||||
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
|
mask->ne[1], mask->nb[1],
|
||||||
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||||
|
K->nb[1], K->nb[2], K->nb[3],
|
||||||
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case 80:
|
||||||
|
flash_attn_ext_f16<80, 2>
|
||||||
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
|
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||||
|
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||||
|
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||||
|
(const char *) src3_extra->data_device[g_main_device], // Mask
|
||||||
|
(float *) dst_extra->data_device[g_main_device], // dst
|
||||||
|
scale,
|
||||||
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
|
mask->ne[1], mask->nb[1],
|
||||||
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||||
|
K->nb[1], K->nb[2], K->nb[3],
|
||||||
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case 128:
|
||||||
|
flash_attn_ext_f16<128, 2>
|
||||||
|
<<<blocks_num, block_dim, shmem, main_stream>>> (
|
||||||
|
(const char *) src0_extra->data_device[g_main_device], // Query
|
||||||
|
(const char *) src1_extra->data_device[g_main_device], // Key
|
||||||
|
(const char *) src2_extra->data_device[g_main_device], // Value
|
||||||
|
(const char *) src3_extra->data_device[g_main_device], // Mask
|
||||||
|
(float *) dst_extra->data_device[g_main_device], // dst
|
||||||
|
scale,
|
||||||
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
|
mask->ne[1], mask->nb[1],
|
||||||
|
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||||
|
K->nb[1], K->nb[2], K->nb[3],
|
||||||
|
KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
|
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
|
||||||
}
|
}
|
||||||
|
@ -10341,6 +10639,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
|
||||||
break;
|
break;
|
||||||
case GGML_OP_FLASH_ATTN:
|
case GGML_OP_FLASH_ATTN:
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -10357,7 +10657,9 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
|
||||||
}
|
}
|
||||||
if(tensor->op == GGML_OP_FLASH_ATTN) {
|
if(tensor->op == GGML_OP_FLASH_ATTN) {
|
||||||
ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor);
|
ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor);
|
||||||
} else {
|
} else if(tensor->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||||
|
ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
|
||||||
|
} else {
|
||||||
func(tensor->src[0], tensor->src[1], tensor);
|
func(tensor->src[0], tensor->src[1], tensor);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
@ -11175,6 +11477,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
case GGML_OP_PAD:
|
case GGML_OP_PAD:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue