update impl
This commit is contained in:
parent
7cea9735ab
commit
2455a8d6c3
1 changed files with 23 additions and 20 deletions
43
ggml-cuda.cu
43
ggml-cuda.cu
|
@ -6134,9 +6134,9 @@ static __global__ void flash_attn_f32(
|
|||
}
|
||||
}
|
||||
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, __half, nvcuda::wmma::col_major> half16x16_a;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, __half, nvcuda::wmma::col_major> half16x16_b;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, __half> half16x16_acc;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
|
||||
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
|
||||
|
||||
// based on metal version
|
||||
template<int D, int Q, int C> // D head size, Q queries per block, C cache items per blocks
|
||||
|
@ -6196,8 +6196,9 @@ static __global__ void flash_attn_ext_f16(
|
|||
for (int i = 0; i < L2; ++i) {
|
||||
// load heads from Q to shared memory
|
||||
for (int j = warp_id; j < Q; j += n_warps) {
|
||||
const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
|
||||
if (iq1 + j < ne01) {
|
||||
pq2[j*T2 + N4*i + lane_id] = ((half2*) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + lane_id];
|
||||
pq2[j*T2 + N4*i + lane_id] = __float22half2_rn(q2[N4*i + lane_id]);
|
||||
} else {
|
||||
pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0);
|
||||
}
|
||||
|
@ -6218,8 +6219,8 @@ static __global__ void flash_attn_ext_f16(
|
|||
__syncthreads();
|
||||
|
||||
{
|
||||
half S[Q] = { 0.0 };
|
||||
half M[Q] = { -INFINITY };
|
||||
half S[Q] = { 0.0 }; // could be half2 S[Q/2] = how fill this array with zeros??
|
||||
half M[Q] = { -INFINITY }; // could be half2 M[Q/2] = better register utilization
|
||||
|
||||
// assume K and V are same shape
|
||||
const int ne22 = ne12;
|
||||
|
@ -6277,12 +6278,12 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
// Q*K^T
|
||||
{
|
||||
half16x16_a mq{};
|
||||
half16x16_b mk{};
|
||||
half16x16_acc mqk{};
|
||||
half16x16_a mq;
|
||||
half16x16_b mk;
|
||||
half16x16_acc mqk;
|
||||
|
||||
for (int cc = 0; cc < C/16; ++cc) {
|
||||
nvcuda::wmma::fill_fragment(mqk, 0); // re fetch
|
||||
nvcuda::wmma::fill_fragment(mqk, 0);
|
||||
|
||||
const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
|
||||
|
||||
|
@ -6297,8 +6298,8 @@ static __global__ void flash_attn_ext_f16(
|
|||
}
|
||||
|
||||
// online softmax
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
const int64_t p = lane_id;
|
||||
for (int j = 0; j < Q; ++j) {
|
||||
const int p = lane_id;
|
||||
|
||||
const half s = ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]);
|
||||
|
||||
|
@ -6311,6 +6312,10 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
S[j] = S[j]*ms + warp_reduce_sum(vs);
|
||||
|
||||
for (int i = 0; i < L2; ++i) {
|
||||
ps2[j*T2 + N4*i + lane_id] *= __half2half2(ms);
|
||||
}
|
||||
|
||||
ss[j*T + p] = vs;
|
||||
}
|
||||
|
||||
|
@ -6318,9 +6323,9 @@ static __global__ void flash_attn_ext_f16(
|
|||
|
||||
// (Q*K^T)*V
|
||||
{
|
||||
half16x16_acc mqkv{};
|
||||
half16x16_a mqk{};
|
||||
half16x16_b mv{};
|
||||
half16x16_acc mqkv;
|
||||
half16x16_a mqk;
|
||||
half16x16_b mv;
|
||||
|
||||
for (int64_t i = 0; i < D16; ++i) {
|
||||
nvcuda::wmma::fill_fragment(mqkv, 0);
|
||||
|
@ -6353,7 +6358,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
// TODO: try parallel reduce
|
||||
if (warp_id == 0) {
|
||||
half S = 0.0;
|
||||
half M = -INFINITY;
|
||||
half M = __float2half(-INFINITY);
|
||||
|
||||
for (int64_t sg = 1; sg < n_warps; ++sg) {
|
||||
for (int64_t j = 0; j < Q; ++j) {
|
||||
|
@ -6395,10 +6400,8 @@ static __global__ void flash_attn_ext_f16(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dq>
|
||||
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
|
||||
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
|
||||
|
@ -10366,7 +10369,7 @@ inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, c
|
|||
|
||||
|
||||
inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(K->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(V->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(mask->type == GGML_TYPE_F32);
|
||||
|
@ -10390,7 +10393,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
|
|||
float scale;
|
||||
memcpy(&scale, KQV->op_params, sizeof(float));
|
||||
|
||||
const int nwarps = Q->ne[1] < 4 ? 4 : 2;
|
||||
const int nwarps = Q->ne[1] < 4 ? 12 : 4;
|
||||
const int nqpb = 16; // queries per block
|
||||
const int ncpw = 32; // cache values per warp (does not work for other values)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue