diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index bafb2ff1c..aeb07c964 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -5989,38 +5989,55 @@ static __global__ void im2col_f32_f16(

 #define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256

-template<int block_size>
-static __global__ void flash_attn_f32(const float* q, const float* k,const float* v, float* dst, float kq_scale,
-        int d_head, int seq_len, int num_heads) {
+template<int block_size, int k_seq_len>
+static __global__ void flash_attn_f32(
+        const float* __restrict__ q,
+        const float* __restrict__ k,
+        const float* __restrict__ v,
+        float* __restrict__ kqv,
+        float kq_scale,
+        int head_dim, int seq_len, int num_heads) {
     const int head = blockIdx.x / seq_len;
-    const int head_size = d_head * seq_len;
+    const int head_size = head_dim * seq_len;
     const int s = blockIdx.x % seq_len;
-    const int tid = threadIdx.x;

-    extern __shared__ char work_data[];
-    float* S = (float*)work_data; // theorical sequent length: 12848, due memory per block limit
-    float* warp_data = (float*)(work_data + seq_len * sizeof(float));
+    extern __shared__ char shmem__[];
+    float* S = (float*)shmem__;
+    float* warp_data = (float*)(shmem__ + seq_len * sizeof(float));

     // QK^T
-    for(int is = tid; is < seq_len; is += block_size) {
+    #pragma unroll
+    for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
+        const int is = threadIdx.x + is0;
+        if(is >= seq_len) {
+            break;
+        }
+
+        const int key_offset = is * head_dim + head * head_size;
+        const int query_offset = s * head_dim + head * head_size;
+
         S[is] = 0.0f;
-        int key_offset = is * d_head + head * head_size;
-        int query_offset = s * d_head + head * head_size;
-        for(int d = 0; d < d_head; d++) {
+        for(int d = 0; d < head_dim; d++) {
             S[is] += k[key_offset + d] * q[query_offset + d];
         }
         S[is] *= kq_scale;
     }

-    __syncthreads();

     float max_val = -INFINITY;
     // get the max
-    for(int is = tid; is < seq_len; is += block_size) {
+    #pragma unroll
+    for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
+        const int is = threadIdx.x + is0;
+        if(is >= seq_len) {
+            break;
+        }
+
         max_val = fmaxf(max_val , S[is]);
     }

     max_val = warp_reduce_max(max_val);
+
     { // get max from all threads
         int warp_id = threadIdx.x / WARP_SIZE;
         int lane_id = threadIdx.x % WARP_SIZE;
@@ -6034,14 +6051,20 @@ static __global__ void flash_attn_f32(const float* q, const float* k,const float

     // softmax(QK^T)
     float sum = 0.0f;
-    for(int is = tid; is < seq_len;is += block_size) {
-        const float val = expf(S[is] - max_val);
-        S[is] = val;
-        sum += val;
+    #pragma unroll
+    for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
+        const int is = threadIdx.x + is0;
+        if(is >= seq_len) {
+            break;
+        }
+
+        S[is] = expf(S[is] - max_val);
+        sum += S[is];
     }

+    __syncthreads();
     sum = warp_reduce_sum(sum);
-    { // sum partials
+    { // softmax sum partials
         int warp_id = threadIdx.x / WARP_SIZE;
         int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
@@ -6053,19 +6076,31 @@ static __global__ void flash_attn_f32(const float* q, const float* k,const float

     float inv_sum = 1.0f / sum;
-    for(int is = tid; is < seq_len; is += block_size) {
+    #pragma unroll
+    for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
+        const int is = threadIdx.x + is0;
+        if(is >= seq_len) {
+            break;
+        }
+
         S[is] *= inv_sum;
     }

-    __syncthreads();
+
     // softmax(QK^T)V
-    for (int d = tid; d < d_head; d += block_size) {
-        int dst_index = d + s * d_head + head * head_size;
-        int value_offset = d * seq_len + head * head_size;
-        dst[dst_index] = 0.0f;
-        for(int ic = 0; ic < seq_len; ic++) {
-            dst[dst_index] += v[value_offset + ic] * S[ic];
+    for (int d = threadIdx.x; d < head_dim; d += block_size) {
+        const int dst_index = d + s * head_dim + head * head_size;
+        const int value_offset = d * seq_len + head * head_size;
+
+        float temp = 0.0f;
+        #pragma unroll
+        for(int ic = 0; ic < k_seq_len; ic++) {
+            if(ic >= seq_len) {
+                break;
+            }
+            temp += v[value_offset + ic] * S[ic];
         }
+        kqv[dst_index] = temp;
     }
 }


@@ -7462,7 +7497,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst,
 static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
     int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
     int num_blocks = num_heads * seq_len;
-    flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
+    flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, 1024><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
         q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
 }

diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp
index fb5e2a8bc..74167ed86 100644
--- a/tests/test-flash-attention.cpp
+++ b/tests/test-flash-attention.cpp
@@ -23,8 +23,9 @@ struct test_model {
     struct ggml_tensor * k;
     struct ggml_tensor * v;
     ggml_backend_t backend = NULL;
-    ggml_backend_buffer_t buffer;
-    struct ggml_context * ctx;
+    ggml_backend_buffer_t buffer = NULL;
+    struct ggml_context * ctx = NULL;
+    bool naive_attn = false;
 };

 static std::vector<float> tensor_to_float(const ggml_tensor * t) {
@@ -216,8 +217,16 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a

     struct ggml_cgraph * gf = ggml_new_graph(ctx0);

-    struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false);
-    ggml_build_forward_expand(gf, result);
+    if(!model.naive_attn) {
+        struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false);
+        ggml_build_forward_expand(gf, result);
+    } else {
+        struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
+        kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0]));
+        kq = ggml_soft_max(ctx0, kq);
+        kq = ggml_mul_mat(ctx0, model.v, kq);
+        ggml_build_forward_expand(gf, kq);
+    }

     // delete the temporally context used to build the graph
     ggml_free(ctx0);
@@ -330,15 +339,18 @@ struct ggml_tensor* compute_graph(const test_model & model, ggml_backend_t backe
 int main(int argc, char ** argv)
 {
     bool compare_backend = false;
+    test_model model;
     for (int i = 1; i < argc; i++) {
         if (strcmp(argv[i], "comp") == 0) {
             compare_backend = true;
+        } else if (strcmp(argv[i], "naive") == 0) {
+            model.naive_attn = true;
         }
     }

     ggml_time_init();

-    test_model model;
+
     load_model(model, true);

     ggml_backend_buffer_t buf_compute; // for compute
@@ -359,9 +371,11 @@ int main(int argc, char ** argv)
     }

     ggml_backend_t backend_cpu = ggml_backend_cpu_init();
-
+    uint64_t compute_time_us__ = ggml_time_us();
     struct ggml_tensor * result = compute_graph(model, backend_cpu, allocr, compare_backend);
     if(!compare_backend) {
+        ggml_backend_synchronize(model.backend);
+        printf("computing time: %.4f ms\n", (ggml_time_us() - compute_time_us__) / 1000.0);
         float* data = new float[ggml_nelements(result)];
         ggml_backend_tensor_get(result, data, 0, ggml_nbytes(result));
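
For reference, below is a minimal CPU sketch (not part of the patch) of the math that flash_attn_f32 computes for each (head, s) pair, under the layout assumptions implied by the kernel's indexing: q and k stored per head as [seq_len][head_dim], v stored transposed per head as [head_dim][seq_len], and the output as [seq_len][head_dim]. The function name attn_reference is hypothetical; it is meant only as a host-side check of the kernel output, analogous to the "naive" ggml graph (mul_mat, scale, soft_max, mul_mat) added in test-flash-attention.cpp.

#include <math.h>
#include <vector>

// Hypothetical host-side reference for flash_attn_f32 (not part of the patch).
static void attn_reference(
        const float * q, const float * k, const float * v, float * dst,
        float kq_scale, int head_dim, int seq_len, int num_heads) {
    const int head_size = head_dim * seq_len;
    std::vector<float> S(seq_len);

    for (int head = 0; head < num_heads; head++) {
        for (int s = 0; s < seq_len; s++) {
            // S = kq_scale * K q_s, same indexing as key_offset / query_offset in the kernel
            float max_val = -INFINITY;
            for (int is = 0; is < seq_len; is++) {
                float acc = 0.0f;
                for (int d = 0; d < head_dim; d++) {
                    acc += k[is*head_dim + head*head_size + d] * q[s*head_dim + head*head_size + d];
                }
                S[is] = acc * kq_scale;
                max_val = fmaxf(max_val, S[is]);
            }

            // numerically stable softmax over the row (subtract the row max before expf)
            float sum = 0.0f;
            for (int is = 0; is < seq_len; is++) {
                S[is] = expf(S[is] - max_val);
                sum += S[is];
            }
            for (int is = 0; is < seq_len; is++) {
                S[is] /= sum;
            }

            // softmax(QK^T)V, with v transposed per head as in value_offset
            for (int d = 0; d < head_dim; d++) {
                float acc = 0.0f;
                for (int ic = 0; ic < seq_len; ic++) {
                    acc += v[d*seq_len + head*head_size + ic] * S[ic];
                }
                dst[d + s*head_dim + head*head_size] = acc;
            }
        }
    }
}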