update impl

This commit is contained in:
FSSRepo 2024-01-27 12:23:40 -05:00
parent 7cea9735ab
commit 2455a8d6c3

View file

@ -6134,9 +6134,9 @@ static __global__ void flash_attn_f32(
} }
} }
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, __half, nvcuda::wmma::col_major> half16x16_a; typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, __half, nvcuda::wmma::col_major> half16x16_b; typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, __half> half16x16_acc; typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
// based on metal version // based on metal version
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per blocks template<int D, int Q, int C> // D head size, Q queries per block, C cache items per blocks
@ -6196,8 +6196,9 @@ static __global__ void flash_attn_ext_f16(
for (int i = 0; i < L2; ++i) { for (int i = 0; i < L2; ++i) {
// load heads from Q to shared memory // load heads from Q to shared memory
for (int j = warp_id; j < Q; j += n_warps) { for (int j = warp_id; j < Q; j += n_warps) {
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
if (iq1 + j < ne01) { if (iq1 + j < ne01) {
pq2[j*T2 + N4*i + lane_id] = ((half2*) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + lane_id]; pq2[j*T2 + N4*i + lane_id] = __float22half2_rn(q2[N4*i + lane_id]);
} else { } else {
pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0); pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0);
} }
@ -6218,8 +6219,8 @@ static __global__ void flash_attn_ext_f16(
__syncthreads(); __syncthreads();
{ {
half S[Q] = { 0.0 }; half S[Q] = { 0.0 }; // could be half2 S[Q/2] = how fill this array with zeros??
half M[Q] = { -INFINITY }; half M[Q] = { -INFINITY }; // could be half2 M[Q/2] = better register utilization
// assume K and V are same shape // assume K and V are same shape
const int ne22 = ne12; const int ne22 = ne12;
@ -6277,12 +6278,12 @@ static __global__ void flash_attn_ext_f16(
// Q*K^T // Q*K^T
{ {
half16x16_a mq{}; half16x16_a mq;
half16x16_b mk{}; half16x16_b mk;
half16x16_acc mqk{}; half16x16_acc mqk;
for (int cc = 0; cc < C/16; ++cc) { for (int cc = 0; cc < C/16; ++cc) {
nvcuda::wmma::fill_fragment(mqk, 0); // re fetch nvcuda::wmma::fill_fragment(mqk, 0);
const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
@ -6297,8 +6298,8 @@ static __global__ void flash_attn_ext_f16(
} }
// online softmax // online softmax
for (int64_t j = 0; j < Q; ++j) { for (int j = 0; j < Q; ++j) {
const int64_t p = lane_id; const int p = lane_id;
const half s = ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]); const half s = ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]);
@ -6311,6 +6312,10 @@ static __global__ void flash_attn_ext_f16(
S[j] = S[j]*ms + warp_reduce_sum(vs); S[j] = S[j]*ms + warp_reduce_sum(vs);
for (int i = 0; i < L2; ++i) {
ps2[j*T2 + N4*i + lane_id] *= __half2half2(ms);
}
ss[j*T + p] = vs; ss[j*T + p] = vs;
} }
@ -6318,9 +6323,9 @@ static __global__ void flash_attn_ext_f16(
// (Q*K^T)*V // (Q*K^T)*V
{ {
half16x16_acc mqkv{}; half16x16_acc mqkv;
half16x16_a mqk{}; half16x16_a mqk;
half16x16_b mv{}; half16x16_b mv;
for (int64_t i = 0; i < D16; ++i) { for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::fill_fragment(mqkv, 0); nvcuda::wmma::fill_fragment(mqkv, 0);
@ -6353,7 +6358,7 @@ static __global__ void flash_attn_ext_f16(
// TODO: try parallel reduce // TODO: try parallel reduce
if (warp_id == 0) { if (warp_id == 0) {
half S = 0.0; half S = 0.0;
half M = -INFINITY; half M = __float2half(-INFINITY);
for (int64_t sg = 1; sg < n_warps; ++sg) { for (int64_t sg = 1; sg < n_warps; ++sg) {
for (int64_t j = 0; j < Q; ++j) { for (int64_t j = 0; j < Q; ++j) {
@ -6395,10 +6400,8 @@ static __global__ void flash_attn_ext_f16(
} }
} }
} }
} }
template<int qk, int qr, dequantize_kernel_t dq> template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@ -10366,7 +10369,7 @@ inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, c
inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) { inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
GGML_ASSERT(Q->type == GGML_TYPE_F16); GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(K->type == GGML_TYPE_F16); GGML_ASSERT(K->type == GGML_TYPE_F16);
GGML_ASSERT(V->type == GGML_TYPE_F16); GGML_ASSERT(V->type == GGML_TYPE_F16);
GGML_ASSERT(mask->type == GGML_TYPE_F32); GGML_ASSERT(mask->type == GGML_TYPE_F32);
@ -10390,7 +10393,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
float scale; float scale;
memcpy(&scale, KQV->op_params, sizeof(float)); memcpy(&scale, KQV->op_params, sizeof(float));
const int nwarps = Q->ne[1] < 4 ? 4 : 2; const int nwarps = Q->ne[1] < 4 ? 12 : 4;
const int nqpb = 16; // queries per block const int nqpb = 16; // queries per block
const int ncpw = 32; // cache values per warp (does not work for other values) const int ncpw = 32; // cache values per warp (does not work for other values)