update impl

FSSRepo 2024-01-27 12:23:40 -05:00
parent 7cea9735ab
commit 2455a8d6c3


@@ -6134,9 +6134,9 @@ static __global__ void flash_attn_f32(
}
}
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, __half, nvcuda::wmma::col_major> half16x16_a;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, __half, nvcuda::wmma::col_major> half16x16_b;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, __half> half16x16_acc;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
// based on metal version
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per blocks
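
Note: the typedef change above only swaps the element type spelling from __half to half; the way the fragments are consumed is unchanged. For reference, a minimal, self-contained sketch (not part of this commit) of the WMMA tile pattern these typedefs feed into: zero the accumulator, load 16x16 col-major A and B tiles, multiply-accumulate, store. It assumes sm_70+ and tiles stored contiguously with a leading dimension of 16.

#include <cuda_fp16.h>
#include <mma.h>

typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;

// one warp computes C = A*B for a single 16x16 tile (launch with 32 threads)
static __global__ void wmma_tile_mul(const half * A, const half * B, half * C) {
    half16x16_a   ma;
    half16x16_b   mb;
    half16x16_acc mc;

    nvcuda::wmma::fill_fragment(mc, 0.0f);     // zero the accumulator
    nvcuda::wmma::load_matrix_sync(ma, A, 16); // leading dimension 16, col-major
    nvcuda::wmma::load_matrix_sync(mb, B, 16);
    nvcuda::wmma::mma_sync(mc, ma, mb, mc);    // mc = ma*mb + mc
    nvcuda::wmma::store_matrix_sync(C, mc, 16, nvcuda::wmma::mem_col_major);
}
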
@@ -6196,8 +6196,9 @@ static __global__ void flash_attn_ext_f16(
for (int i = 0; i < L2; ++i) {
// load heads from Q to shared memory
for (int j = warp_id; j < Q; j += n_warps) {
+const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
if (iq1 + j < ne01) {
-pq2[j*T2 + N4*i + lane_id] = ((half2*) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + lane_id];
+pq2[j*T2 + N4*i + lane_id] = __float22half2_rn(q2[N4*i + lane_id]);
} else {
pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0);
}
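
For context on the hunk above: Q now arrives in F32 (see the GGML_ASSERT change further down), so each thread reads a float2 from global memory and converts it to half2 while staging it into shared memory. A standalone sketch of that conversion step under that assumption; the names load_q_f32_to_half2, dst and n2 are illustrative and not in the commit.

#include <cuda_fp16.h>

// dst: half2 staging buffer, q: F32 source row, n2: number of float2 pairs to copy
static __global__ void load_q_f32_to_half2(half2 * dst, const float * q, int n2) {
    const float2 * q2 = (const float2 *) q;
    for (int i = threadIdx.x; i < n2; i += blockDim.x) {
        dst[i] = __float22half2_rn(q2[i]); // round-to-nearest on both lanes
    }
}
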
@@ -6218,8 +6219,8 @@ static __global__ void flash_attn_ext_f16(
__syncthreads();
{
-half S[Q] = { 0.0 };
-half M[Q] = { -INFINITY };
+half S[Q] = { 0.0 }; // could be half2 S[Q/2], but how to fill this array with zeros?
+half M[Q] = { -INFINITY }; // could be half2 M[Q/2] for better register utilization
// assume K and V are same shape
const int ne22 = ne12;
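
Regarding the question left in the comments above: a half2 register array can be zero-filled with an unrolled loop of make_half2, relying on the same implicit float-to-half conversion as the existing make_half2(0.0, 0.0) in this kernel. A small sketch under that assumption; init_state is a hypothetical helper, not in the commit.

#include <cuda_fp16.h>
#include <math.h>

template<int Q>
static __device__ void init_state(half2 (&S2)[Q/2], half2 (&M2)[Q/2]) {
    #pragma unroll
    for (int j = 0; j < Q/2; ++j) {
        S2[j] = make_half2(0.0f, 0.0f);           // running softmax denominators
        M2[j] = make_half2(-INFINITY, -INFINITY); // running maxima
    }
}
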
@@ -6277,12 +6278,12 @@ static __global__ void flash_attn_ext_f16(
// Q*K^T
{
-half16x16_a mq{};
-half16x16_b mk{};
-half16x16_acc mqk{};
+half16x16_a mq;
+half16x16_b mk;
+half16x16_acc mqk;
for (int cc = 0; cc < C/16; ++cc) {
-nvcuda::wmma::fill_fragment(mqk, 0); // re fetch
+nvcuda::wmma::fill_fragment(mqk, 0);
const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
@@ -6297,8 +6298,8 @@ static __global__ void flash_attn_ext_f16(
}
// online softmax
-for (int64_t j = 0; j < Q; ++j) {
-const int64_t p = lane_id;
+for (int j = 0; j < Q; ++j) {
+const int p = lane_id;
const half s = ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]);
@@ -6311,6 +6312,10 @@ static __global__ void flash_attn_ext_f16(
S[j] = S[j]*ms + warp_reduce_sum(vs);
+for (int i = 0; i < L2; ++i) {
+ps2[j*T2 + N4*i + lane_id] *= __half2half2(ms);
+}
ss[j*T + p] = vs;
}
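
The line S[j] = S[j]*ms + warp_reduce_sum(vs) above uses the file's warp-level sum helper; for reference, the butterfly reduction such a helper typically performs looks like the sketch below, written here as a half-precision variant for illustration and not taken from the file.

#include <cuda_fp16.h>

static __device__ __forceinline__ half warp_reduce_sum_half(half x) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        // each lane adds the value held by the lane XOR'd with 'offset'
        x = __hadd(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x; // every lane of the warp now holds the full sum
}
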
@@ -6318,9 +6323,9 @@ static __global__ void flash_attn_ext_f16(
// (Q*K^T)*V
{
-half16x16_acc mqkv{};
-half16x16_a mqk{};
-half16x16_b mv{};
+half16x16_acc mqkv;
+half16x16_a mqk;
+half16x16_b mv;
for (int64_t i = 0; i < D16; ++i) {
nvcuda::wmma::fill_fragment(mqkv, 0);
@@ -6353,7 +6358,7 @@ static __global__ void flash_attn_ext_f16(
// TODO: try parallel reduce
if (warp_id == 0) {
half S = 0.0;
-half M = -INFINITY;
+half M = __float2half(-INFINITY);
for (int64_t sg = 1; sg < n_warps; ++sg) {
for (int64_t j = 0; j < Q; ++j) {
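
The reduction warp 0 performs above merges the per-warp (S, M) softmax states; the per-query combine step is the standard flash-attention rescale, sketched below as a standalone helper. The name combine_softmax_state is illustrative and this is an assumption about the merge, not the commit's exact code.

#include <cuda_fp16.h>

static __device__ void combine_softmax_state(half & S, half & M, const half S_other, const half M_other) {
    const half M_new = __hgt(M, M_other) ? M : M_other; // new running maximum
    const half ms0   = hexp(M       - M_new);           // rescale our partial sum
    const half ms1   = hexp(M_other - M_new);           // rescale the other warp's sum
    S = S*ms0 + S_other*ms1;                            // merged denominator
    M = M_new;
}
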
@@ -6395,10 +6400,8 @@ static __global__ void flash_attn_ext_f16(
}
}
}
}
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@@ -10366,7 +10369,7 @@ inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, c
inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
-GGML_ASSERT(Q->type == GGML_TYPE_F16);
+GGML_ASSERT(Q->type == GGML_TYPE_F32);
GGML_ASSERT(K->type == GGML_TYPE_F16);
GGML_ASSERT(V->type == GGML_TYPE_F16);
GGML_ASSERT(mask->type == GGML_TYPE_F32);
@@ -10390,7 +10393,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
-const int nwarps = Q->ne[1] < 4 ? 4 : 2;
+const int nwarps = Q->ne[1] < 4 ? 12 : 4;
const int nqpb = 16; // queries per block
const int ncpw = 32; // cache values per warp (does not work for other values)
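
With the new warp count, the launch shape for flash_attn_ext_f16 would look roughly like the sketch below. This is an assumption about the host wrapper; flash_attn_ext_launch_dims and the extent names are illustrative, not the commit's exact code.

#include <cuda_runtime.h>

// ne1 = number of queries, ne2 = number of heads, ne3 = batch size
static void flash_attn_ext_launch_dims(int ne1, int ne2, int ne3, dim3 & grid, dim3 & block) {
    const int nwarps = ne1 < 4 ? 12 : 4; // more warps when only a few queries are present
    const int nqpb   = 16;               // queries per block
    block = dim3(32, nwarps, 1);         // 32 lanes per warp, nwarps warps per block
    grid  = dim3((ne1 + nqpb - 1)/nqpb,  // blocks tiling the query dimension
                 ne2, ne3);              // one block per head and batch entry
}
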