cuda: add flash attention + test

parent 4f4bf35f46
commit f7bcfb0566

3 changed files with 526 additions and 1 deletion

ggml-cuda.cu (142 lines changed; the other changed files are not shown in this excerpt)
@@ -5987,6 +5987,88 @@ static __global__ void im2col_f32_f16(
     }
 }

+#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256
+
+template<int block_size>
+static __global__ void flash_attn_f32(const float* q, const float* k, const float* v, float* dst, float kq_scale,
+        int d_head, int seq_len, int num_heads) {
+    const int head = blockIdx.x / seq_len;
+    const int head_size = d_head * seq_len;
+    const int s = blockIdx.x % seq_len;
+    const int tid = threadIdx.x;
+
+    extern __shared__ char work_data[];
+    float* S = (float*)work_data; // theoretical sequence length limit: 12848, due to the shared-memory-per-block limit
+    float* warp_data = (float*)(work_data + seq_len * sizeof(float));
+
+    // QK^T
+    for(int is = tid; is < seq_len; is += block_size) {
+        S[is] = 0.0f;
+        int key_offset = is * d_head + head * head_size;
+        int query_offset = s * d_head + head * head_size;
+        for(int d = 0; d < d_head; d++) {
+            S[is] += k[key_offset + d] * q[query_offset + d];
+        }
+        S[is] *= kq_scale;
+    }
+
+    __syncthreads();
+
+    float max_val = -INFINITY;
+    // get the max
+    for(int is = tid; is < seq_len; is += block_size) {
+        max_val = fmaxf(max_val, S[is]);
+    }
+
+    max_val = warp_reduce_max(max_val);
+    { // get max from all threads
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            warp_data[warp_id] = max_val;
+        }
+        __syncthreads();
+        max_val = warp_data[lane_id];
+        max_val = warp_reduce_max(max_val);
+    }
+
+    // softmax(QK^T)
+    float sum = 0.0f;
+    for(int is = tid; is < seq_len; is += block_size) {
+        const float val = expf(S[is] - max_val);
+        S[is] = val;
+        sum += val;
+    }
+
+    sum = warp_reduce_sum(sum);
+    { // sum partials
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            warp_data[warp_id] = sum;
+        }
+        __syncthreads();
+        sum = warp_data[lane_id];
+        sum = warp_reduce_sum(sum);
+    }
+
+    float inv_sum = 1.0f / sum;
+    for(int is = tid; is < seq_len; is += block_size) {
+        S[is] *= inv_sum;
+    }
+
+    __syncthreads();
+    // softmax(QK^T)V
+    for (int d = tid; d < d_head; d += block_size) {
+        int dst_index = d + s * d_head + head * head_size;
+        int value_offset = d * seq_len + head * head_size;
+        dst[dst_index] = 0.0f;
+        for(int ic = 0; ic < seq_len; ic++) {
+            dst[dst_index] += v[value_offset + ic] * S[ic];
+        }
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                           const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
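The kernel above relies on WARP_SIZE, warp_reduce_max, and warp_reduce_sum, which ggml-cuda.cu already defines elsewhere in the file; they are not part of this diff. For readers without the full file open, a rough sketch of their usual shuffle-based shape (an approximation, not copied from the commit) looks like:

// Approximate sketch of the pre-existing helpers the kernel depends on.
#define WARP_SIZE 32

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
        // butterfly exchange: every lane ends up holding the sum of all 32 lanes
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}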
@@ -7377,6 +7459,13 @@ static void im2col_f32_f16_cuda(const float* x, half* dst,
     im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
 }

+static void flash_attn_f32_cuda(const float* q, const float* k, const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
+    int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
+    int num_blocks = num_heads * seq_len;
+    flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
+            q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256

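flash_attn_f32_cuda launches one block per (head, query position) and sizes the dynamic shared memory as one float per key position plus WARP_SIZE floats for per-warp partials. A quick sizing check along those lines (a worked example with illustrative values, not part of the commit; 48 KiB is the common default per-block shared-memory budget):

#include <cstdio>

// Worked example of the launch sizing used above (illustrative values, not from the commit).
int main() {
    const int seq_len   = 4096;
    const int num_heads = 32;
    const int warp_size = 32; // WARP_SIZE in ggml-cuda.cu

    const size_t shmem  = seq_len * sizeof(float) + warp_size * sizeof(float); // 16512 bytes
    const int    blocks = num_heads * seq_len;                                 // 131072 blocks

    // With the common 48 KiB default of shared memory per block, the S buffer
    // caps seq_len at roughly (48*1024 - warp_size*4) / 4 = 12256 floats,
    // which is the per-block limit the kernel comment alludes to.
    std::printf("shared memory per block: %zu bytes, grid: %d blocks\n", shmem, blocks);
    return 0;
}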
@@ -9900,6 +9989,51 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     }
 }

+inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV) {
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(K->type == GGML_TYPE_F32);
+    GGML_ASSERT(V->type == GGML_TYPE_F32);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
+
+    ggml_cuda_set_device(g_main_device);
+    const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
+    ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
+    ggml_tensor_extra_gpu * dst_extra  = (ggml_tensor_extra_gpu *) KQV->extra;
+
+    const int64_t d_head = Q->ne[0];
+    const int64_t sequence_length = Q->ne[1];
+    const int64_t num_heads = Q->ne[2];
+
+    GGML_ASSERT(Q->ne[0] == d_head);
+    GGML_ASSERT(K->ne[0] == d_head);
+    GGML_ASSERT(V->ne[1] == d_head);
+
+    GGML_ASSERT(Q->ne[1] == sequence_length);
+    GGML_ASSERT(K->ne[1] == sequence_length);
+    GGML_ASSERT(V->ne[0] == sequence_length);
+
+    GGML_ASSERT(Q->ne[2] == num_heads);
+    GGML_ASSERT(K->ne[2] == num_heads);
+    GGML_ASSERT(V->ne[2] == num_heads);
+
+    float KQ_scale = 1.0f / sqrtf((float)d_head);
+
+    flash_attn_f32_cuda(
+            (float *) src0_extra->data_device[g_main_device], // Query
+            (float *) src1_extra->data_device[g_main_device], // Key
+            (float *) src2_extra->data_device[g_main_device], // Value
+            (float *) dst_extra->data_device[g_main_device],  // dst
+            KQ_scale, d_head, sequence_length, num_heads, main_stream);
+}
+
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }

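The commit message mentions a test, but the test file is not part of this excerpt. To make the layout the wrapper asserts explicit — Q, K and KQV stored per head as [seq_len][d_head], V pre-transposed as [d_head][seq_len] (V->ne[0] == seq_len, V->ne[1] == d_head), and scores scaled by 1/sqrt(d_head) — here is a hand-written sketch of a CPU reference for a single (head, query position) output row that mirrors the kernel's indexing; it is not the test shipped in the commit:

#include <cmath>
#include <vector>

// Sketch: CPU reference for one output row of softmax(QK^T * scale) V,
// using the same offsets as flash_attn_f32. Not from the commit.
static void flash_attn_row_ref(const float * q, const float * k, const float * v, float * dst,
                               float kq_scale, int d_head, int seq_len, int head, int s) {
    const int head_size = d_head * seq_len;
    std::vector<float> S(seq_len);

    // scaled dot product of query s against every key, tracking the max for stability
    float max_val = -INFINITY;
    for (int is = 0; is < seq_len; ++is) {
        float dot = 0.0f;
        for (int d = 0; d < d_head; ++d) {
            dot += k[is*d_head + head*head_size + d] * q[s*d_head + head*head_size + d];
        }
        S[is] = dot * kq_scale;
        max_val = std::fmax(max_val, S[is]);
    }

    // softmax over the key dimension
    float sum = 0.0f;
    for (int is = 0; is < seq_len; ++is) {
        S[is] = std::exp(S[is] - max_val);
        sum += S[is];
    }

    // weighted sum over V, which is laid out [d_head][seq_len] within each head
    for (int d = 0; d < d_head; ++d) {
        float acc = 0.0f;
        for (int ic = 0; ic < seq_len; ++ic) {
            acc += v[d*seq_len + head*head_size + ic] * S[ic];
        }
        dst[d + s*d_head + head*head_size] = acc / sum;
    }
}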
@@ -10168,6 +10302,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
         case GGML_OP_ARGSORT:
             func = ggml_cuda_argsort;
             break;
+        case GGML_OP_FLASH_ATTN:
+            break;
         default:
             return false;
     }

@@ -10182,7 +10318,11 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return true;
     }
-    func(tensor->src[0], tensor->src[1], tensor);
+    if(tensor->op == GGML_OP_FLASH_ATTN) {
+        ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+    } else {
+        func(tensor->src[0], tensor->src[1], tensor);
+    }
     return true;
 }
