fixup! CUDA: add FP32 FlashAttention vector kernel
commit 41f5f3a4e4
parent bbeb952aca

2 changed files with 58 additions and 8 deletions
@@ -14,6 +14,10 @@ static __global__ void flash_attn_vec_ext_f16(
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
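The four new kernel arguments carry the ALiBi state: max_bias enables the bias, while m0, m1, and n_head_log2 are precomputed on the host so the kernel only needs a single powf per block. For a 0-based head index h, the slope used further down is

    slope(h) = m0^(h + 1)                      if h <  n_head_log2
    slope(h) = m1^(2*(h - n_head_log2) + 1)    if h >= n_head_log2

where n_head_log2 is the largest power of two not exceeding the head count; this is the standard ALiBi head-slope scheme extended to non-power-of-two head counts.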
@@ -49,6 +53,18 @@ static __global__ void flash_attn_vec_ext_f16(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);

+    half slopeh = __float2half(1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+    }
+
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
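Since the slope depends only on blockIdx.y (the head index), every thread in the block computes the same value. A minimal host-side mirror of this computation can be handy for unit tests; the function below is a hypothetical sketch, not part of the commit:

    // Hypothetical host-side reference for the per-head ALiBi slope.
    static float alibi_slope_ref(int h, uint32_t n_head_log2, float m0, float m1) {
        const float base = h < (int) n_head_log2 ? m0 : m1;
        const int   exph = h < (int) n_head_log2 ? h + 1 : 2*(h - (int) n_head_log2) + 1;
        return powf(base, (float) exph);
    }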
@@ -132,7 +148,7 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             sum2[j] = warp_reduce_sum(sum2[j]);
             half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
-            sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+            sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);

             if (ncols == 1) {
                 kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
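The only change to the inner loop is that the mask value is scaled by the per-head slope before being added to the KQ dot product; with ALiBi the mask is expected to carry the positional term, so slopeh*maskh[...] yields the full bias. For orientation, the half2-to-half fold on the line above it uses the cuda_fp16.h lane extractors:

    // Sketch: fold the two fp16 lanes of a half2 accumulator into one half,
    // as done above after the warp reduction.
    __device__ __forceinline__ half fold_half2(const half2 v) {
        return __low2half(v) + __high2half(v);
    }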
@@ -244,8 +260,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
     const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
     const int  shmem = 0;

-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

     flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
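scale and max_bias arrive packed back-to-back as floats at the start of KQV->op_params, and the two ALiBi bases are derived from the largest power of two not exceeding the head count. A standalone host-side check of this arithmetic (illustrative, not part of the commit), for a hypothetical 12-head model with max_bias = 8.0f:

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const uint32_t n_head   = 12;
        const float    max_bias = 8.0f;

        // Largest power of two <= n_head, as in the launcher above.
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); // 8

        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2); // 2^-1.0  = 0.5000
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // 2^-0.5 ~= 0.7071

        printf("n_head_log2 = %u, m0 = %.4f, m1 = %.4f\n", n_head_log2, m0, m1);
        return 0;
    }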
@@ -254,7 +279,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
             (const char *) V->data,
             mask ? ((const char *) mask->data) : nullptr,
             parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-            scale,
+            scale, max_bias, m0, m1, n_head_log2,
             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
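The new scalars are passed positionally right after scale, matching the extended kernel signature. Note how the launch grid lines up with the device-side ALiBi code: blockIdx.y is exactly the head index h used to pick the slope.

    // Grid shape implied by blocks_num above:
    //   gridDim.x = parallel_blocks * ceil(ne01 / cols_per_block)  -- query columns
    //   gridDim.y = Q->ne[2]                                       -- n_head, read as blockIdx.y
    //   gridDim.z = Q->ne[3]                                       -- batch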
@@ -14,6 +14,10 @@ static __global__ void flash_attn_vec_ext_f32(
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -48,6 +52,18 @@ static __global__ void flash_attn_vec_ext_f32(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);

+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = powf(base, exph);
+    }
+
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
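The FP32 path mirrors the FP16 one but keeps the slope in float, avoiding the __float2half round-trip. warp_reduce_sum is not shown in this diff; it is assumed to be a conventional shuffle-based butterfly reduction along these lines (a sketch, not the project's definitive implementation):

    static __device__ __forceinline__ float warp_reduce_sum(float x) {
    #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            // XOR butterfly: after 5 steps every lane holds the warp-wide sum.
            x += __shfl_xor_sync(0xffffffff, x, offset, 32);
        }
        return x;
    }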
@@ -127,7 +143,7 @@ static __global__ void flash_attn_vec_ext_f32(
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
             sum[j] = warp_reduce_sum(sum[j]);
-            sum[j] += mask ? __half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+            sum[j] += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;

             kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum[j]);
@@ -230,8 +246,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
     const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
     const int  shmem = 0;

-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

     flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
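As in the FP16 launcher, the parameters are read from op_params at fixed float offsets. The producer side is assumed to pack them in the same order, along these lines (illustrative; the packing code is not part of this commit):

    // Assumed graph-side packing that the memcpy calls above mirror.
    float params[2] = { scale, max_bias };
    memcpy(KQV->op_params, params, sizeof(params));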
@@ -240,7 +265,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
             (const char *) V->data,
             mask ? ((const char *) mask->data) : nullptr,
             parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-            scale,
+            scale, max_bias, m0, m1, n_head_log2,
             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,