update impl
This commit is contained in:
parent 7cea9735ab
commit 2455a8d6c3

1 changed file with 23 additions and 20 deletions

ggml-cuda.cu (43 changed lines: +23 −20)
@@ -6134,9 +6134,9 @@ static __global__ void flash_attn_f32(
     }
 }
 
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, __half, nvcuda::wmma::col_major> half16x16_a;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, __half, nvcuda::wmma::col_major> half16x16_b;
-typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, __half> half16x16_acc;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_a;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_b;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
 
 // based on metal version
 template<int D, int Q, int C> // D head size, Q queries per block, C cache items per blocks
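Note: `half` is the cuda_fp16.h typedef for `__half`, so these fragment typedefs keep their meaning; the edit just matches the spelling used elsewhere in ggml-cuda.cu. For reference, a minimal sketch of how such 16x16x16 fragments are driven (the helper name, tile pointers and leading dimensions are illustrative, not from this file; requires sm_70+ and <mma.h>):

    // Accumulate one 16x16 half tile: C = A*B + C, using the typedefs above.
    static __device__ void wmma_tile_sketch(const half * A, const half * B, half * C,
                                            int lda, int ldb, int ldc) {
        half16x16_a   a;
        half16x16_b   b;
        half16x16_acc c;

        nvcuda::wmma::fill_fragment(c, 0.0);                 // zero the accumulator
        nvcuda::wmma::load_matrix_sync(a, A, lda);           // 16x16 tile of A (col_major)
        nvcuda::wmma::load_matrix_sync(b, B, ldb);           // 16x16 tile of B (col_major)
        nvcuda::wmma::mma_sync(c, a, b, c);                  // tensor-core multiply-accumulate
        nvcuda::wmma::store_matrix_sync(C, c, ldc, nvcuda::wmma::mem_col_major);
    }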
@@ -6196,8 +6196,9 @@ static __global__ void flash_attn_ext_f16(
     for (int i = 0; i < L2; ++i) {
         // load heads from Q to shared memory
         for (int j = warp_id; j < Q; j += n_warps) {
+            const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
             if (iq1 + j < ne01) {
-                pq2[j*T2 + N4*i + lane_id] = ((half2*) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)))[N4*i + lane_id];
+                pq2[j*T2 + N4*i + lane_id] = __float22half2_rn(q2[N4*i + lane_id]);
             } else {
                 pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0);
             }
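Note: Q is now read as F32 (see the Q->type assert change further down) and converted to half2 on load; __float22half2_rn packs a float2 into a half2 with round-to-nearest. A minimal conversion sketch (hypothetical helper, cuda_fp16.h intrinsics only):

    // Convert n float2 pairs to half2 (src must hold 2*n floats).
    static __device__ void f32_row_to_half2(const float * src, half2 * dst, int n) {
        const float2 * src2 = (const float2 *) src;   // reinterpret as float2 pairs
        for (int i = 0; i < n; ++i) {
            dst[i] = __float22half2_rn(src2[i]);      // round-to-nearest conversion to half2
        }
    }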
@@ -6218,8 +6219,8 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();
 
     {
-        half S[Q] = { 0.0 };
-        half M[Q] = { -INFINITY };
+        half S[Q] = { 0.0 };       // could be half2 S[Q/2] = how fill this array with zeros??
+        half M[Q] = { -INFINITY }; // could be half2 M[Q/2] = better register utilization
 
         // assume K and V are same shape
         const int ne22 = ne12;
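Note: on the question left in the new comment (how to zero-fill a half2 S[Q/2]): one straightforward option, assuming Q is even and using hypothetical S2/M2 names, is an explicit unrolled fill with the fp16 intrinsics:

    // half2 variants of the softmax state, filled explicitly (sketch, assumes Q is even).
    half2 S2[Q/2];
    half2 M2[Q/2];

    #pragma unroll
    for (int j = 0; j < Q/2; ++j) {
        S2[j] = make_half2(0.0, 0.0);                   // running softmax denominators
        M2[j] = __half2half2(__float2half(-INFINITY));  // running maxima
    }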
@@ -6277,12 +6278,12 @@ static __global__ void flash_attn_ext_f16(
 
            // Q*K^T
            {
-               half16x16_a mq{};
-               half16x16_b mk{};
-               half16x16_acc mqk{};
+               half16x16_a mq;
+               half16x16_b mk;
+               half16x16_acc mqk;
 
                for (int cc = 0; cc < C/16; ++cc) {
-                   nvcuda::wmma::fill_fragment(mqk, 0); // re fetch
+                   nvcuda::wmma::fill_fragment(mqk, 0);
 
                    const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
 
@@ -6297,8 +6298,8 @@ static __global__ void flash_attn_ext_f16(
            }
 
            // online softmax
-           for (int64_t j = 0; j < Q; ++j) {
-               const int64_t p = lane_id;
+           for (int j = 0; j < Q; ++j) {
+               const int p = lane_id;
 
                const half s = ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]);
 
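Note: this loop is the standard online-softmax update from flash attention: each new score s updates a running maximum M and running sum S, and everything accumulated so far is rescaled by ms = exp(M_old - M_new). A scalar reference of the recurrence (float for clarity; the names mirror the kernel's S, M, ms, vs, but this is a sketch, not the kernel's code):

    static __device__ void online_softmax_step(float s, float & M, float & S, float & ms, float & vs) {
        const float M_new = fmaxf(M, s);   // new running maximum
        ms = expf(M - M_new);              // rescale factor for previously accumulated values
        vs = expf(s - M_new);              // weight of the new score
        S  = S*ms + vs;                    // updated denominator
        M  = M_new;
    }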
@@ -6311,6 +6312,10 @@ static __global__ void flash_attn_ext_f16(
 
                S[j] = S[j]*ms + warp_reduce_sum(vs);
 
+               for (int i = 0; i < L2; ++i) {
+                   ps2[j*T2 + N4*i + lane_id] *= __half2half2(ms);
+               }
+
                ss[j*T + p] = vs;
            }
 
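Note: the added loop is the other half of the online-softmax update: when the running maximum of row j moves, the partial (Q*K^T)*V accumulator for that row must be rescaled by the same ms before new values are added; __half2half2(ms) broadcasts the scalar into both half lanes. A minimal sketch of that pattern (hypothetical helper):

    // Rescale a row of n half2 accumulators by a scalar half factor.
    static __device__ void rescale_row_half2(half2 * row, int n, half ms) {
        const half2 ms2 = __half2half2(ms);   // broadcast ms into both lanes
        for (int i = 0; i < n; ++i) {
            row[i] *= ms2;                    // element-wise half2 multiply (sm_53+)
        }
    }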
@@ -6318,9 +6323,9 @@ static __global__ void flash_attn_ext_f16(
 
            // (Q*K^T)*V
            {
-               half16x16_acc mqkv{};
-               half16x16_a mqk{};
-               half16x16_b mv{};
+               half16x16_acc mqkv;
+               half16x16_a mqk;
+               half16x16_b mv;
 
                for (int64_t i = 0; i < D16; ++i) {
                    nvcuda::wmma::fill_fragment(mqkv, 0);
@@ -6353,7 +6358,7 @@ static __global__ void flash_attn_ext_f16(
    // TODO: try parallel reduce
    if (warp_id == 0) {
        half S = 0.0;
-       half M = -INFINITY;
+       half M = __float2half(-INFINITY);
 
        for (int64_t sg = 1; sg < n_warps; ++sg) {
            for (int64_t j = 0; j < Q; ++j) {
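Note: the new spelling makes the float->half conversion explicit: `half M = -INFINITY;` relies on the implicit __half(float) constructor, which is unavailable when __CUDA_NO_HALF_CONVERSIONS__ is defined, while the intrinsic form always compiles:

    // half M = -INFINITY;              // needs the implicit float -> half constructor
    half M = __float2half(-INFINITY);   // compiles regardless of the conversion macros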
@@ -6395,10 +6400,8 @@ static __global__ void flash_attn_ext_f16(
            }
        }
    }
-
 }
-
 
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
        const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@@ -10366,7 +10369,7 @@ inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, c
 
 
 inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
-   GGML_ASSERT(Q->type == GGML_TYPE_F16);
+   GGML_ASSERT(Q->type == GGML_TYPE_F32);
    GGML_ASSERT(K->type == GGML_TYPE_F16);
    GGML_ASSERT(V->type == GGML_TYPE_F16);
    GGML_ASSERT(mask->type == GGML_TYPE_F32);
@@ -10390,7 +10393,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
    float scale;
    memcpy(&scale, KQV->op_params, sizeof(float));
 
-   const int nwarps = Q->ne[1] < 4 ? 4 : 2;
+   const int nwarps = Q->ne[1] < 4 ? 12 : 4;
    const int nqpb = 16; // queries per block
    const int ncpw = 32; // cache values per warp (does not work for other values)
 
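Note: nqpb (queries per block), ncpw (cache values per warp) and the new nwarps values (12 for small query counts, 4 otherwise) determine the launch shape of flash_attn_ext_f16. A hedged sketch of how such parameters typically map onto a launch; WARP_SIZE = 32, the grid mapping and the omitted shared-memory size are assumptions, not this file's actual code:

    const dim3 block_dim(32, nwarps, 1);                                  // 32 lanes x nwarps warps per block
    const dim3 grid_dim((Q->ne[1] + nqpb - 1)/nqpb, Q->ne[2], Q->ne[3]);  // blocks over queries, heads, batch
    // template arguments follow <D head size, Q queries per block, C cache items per block>:
    // flash_attn_ext_f16<D, nqpb, ncpw><<<grid_dim, block_dim, shmem, stream>>>(/* ... */);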