diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 568c411af..bb65ca642 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -5987,6 +5987,88 @@ static __global__ void im2col_f32_f16(
     }
 }
 
+#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256
+
+template<int block_size>
+static __global__ void flash_attn_f32(const float* q, const float* k, const float* v, float* dst, float kq_scale,
+        int d_head, int seq_len, int num_heads) {
+    const int head = blockIdx.x / seq_len;
+    const int head_size = d_head * seq_len;
+    const int s = blockIdx.x % seq_len;
+    const int tid = threadIdx.x;
+
+    extern __shared__ char work_data[];
+    float* S = (float*)work_data; // theoretical max sequence length: 12848, due to the shared memory limit per block
+    float* warp_data = (float*)(work_data + seq_len * sizeof(float));
+
+    // QK^T
+    for(int is = tid; is < seq_len; is += block_size) {
+        S[is] = 0.0f;
+        int key_offset = is * d_head + head * head_size;
+        int query_offset = s * d_head + head * head_size;
+        for(int d = 0; d < d_head; d++) {
+            S[is] += k[key_offset + d] * q[query_offset + d];
+        }
+        S[is] *= kq_scale;
+    }
+
+    __syncthreads();
+
+    float max_val = -INFINITY;
+    // get the max
+    for(int is = tid; is < seq_len; is += block_size) {
+        max_val = fmaxf(max_val, S[is]);
+    }
+
+    max_val = warp_reduce_max(max_val);
+    { // get max from all threads
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            warp_data[warp_id] = max_val;
+        }
+        __syncthreads();
+        max_val = warp_data[lane_id];
+        max_val = warp_reduce_max(max_val);
+    }
+
+    // softmax(QK^T)
+    float sum = 0.0f;
+    for(int is = tid; is < seq_len; is += block_size) {
+        const float val = expf(S[is] - max_val);
+        S[is] = val;
+        sum += val;
+    }
+
+    sum = warp_reduce_sum(sum);
+    { // sum partials
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            warp_data[warp_id] = sum;
+        }
+        __syncthreads();
+        sum = warp_data[lane_id];
+        sum = warp_reduce_sum(sum);
+    }
+
+    float inv_sum = 1.0f / sum;
+    for(int is = tid; is < seq_len; is += block_size) {
+        S[is] *= inv_sum;
+    }
+
+    __syncthreads();
+    // softmax(QK^T)V
+    for (int d = tid; d < d_head; d += block_size) {
+        int dst_index = d + s * d_head + head * head_size;
+        int value_offset = d * seq_len + head * head_size;
+        dst[dst_index] = 0.0f;
+        for(int ic = 0; ic < seq_len; ic++) {
+            dst[dst_index] += v[value_offset + ic] * S[ic];
+        }
+    }
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                              const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
@@ -7377,6 +7459,13 @@ static void im2col_f32_f16_cuda(const float* x, half* dst,
     im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
 }
 
+static void flash_attn_f32_cuda(const float* q, const float* k, const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
+    int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
+    int num_blocks = num_heads * seq_len;
+    flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
+            q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256
 
@@ -9900,6 +9989,51 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     }
 }
 
+inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV) {
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(K->type == GGML_TYPE_F32);
+    GGML_ASSERT(V->type == GGML_TYPE_F32);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
+
+    ggml_cuda_set_device(g_main_device);
+    const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
+    ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) KQV->extra;
+
+    const int64_t d_head = Q->ne[0];
+    const int64_t sequence_length = Q->ne[1];
+    const int64_t num_heads = Q->ne[2];
+
+    GGML_ASSERT(Q->ne[0] == d_head);
+    GGML_ASSERT(K->ne[0] == d_head);
+    GGML_ASSERT(V->ne[1] == d_head);
+
+    GGML_ASSERT(Q->ne[1] == sequence_length);
+    GGML_ASSERT(K->ne[1] == sequence_length);
+    GGML_ASSERT(V->ne[0] == sequence_length);
+
+    GGML_ASSERT(Q->ne[2] == num_heads);
+    GGML_ASSERT(K->ne[2] == num_heads);
+    GGML_ASSERT(V->ne[2] == num_heads);
+
+    float KQ_scale = 1.0f / sqrtf((float)d_head);
+
+    flash_attn_f32_cuda(
+        (float *) src0_extra->data_device[g_main_device], // Query
+        (float *) src1_extra->data_device[g_main_device], // Key
+        (float *) src2_extra->data_device[g_main_device], // Value
+        (float *) dst_extra->data_device[g_main_device],  // dst
+        KQ_scale, d_head, sequence_length, num_heads, main_stream);
+}
+
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }
@@ -10168,6 +10302,8 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
         case GGML_OP_ARGSORT:
             func = ggml_cuda_argsort;
             break;
+        case GGML_OP_FLASH_ATTN:
+            break;
         default:
             return false;
     }
@@ -10182,7 +10318,11 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return true;
     }
-    func(tensor->src[0], tensor->src[1], tensor);
+    if(tensor->op == GGML_OP_FLASH_ATTN) {
+        ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+    } else {
+        func(tensor->src[0], tensor->src[1], tensor);
+    }
     return true;
 }
 
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 7c932240d..bc5649989 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -52,6 +52,8 @@ llama_build_and_test_executable(test-backend-ops.cpp)
 
 llama_build_and_test_executable(test-rope.cpp)
 
+llama_build_executable(test-flash-attention.cpp)
+
 # dummy executable - not installed
 get_filename_component(TEST_TARGET test-c.c NAME_WE)
 add_executable(${TEST_TARGET} test-c.c)
diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp
new file mode 100644
index 000000000..c99ad719d
--- /dev/null
+++ b/tests/test-flash-attention.cpp
@@ -0,0 +1,383 @@
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+
+#define GGML_USE_CUBLAS
+
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
+
+#include <cfloat>
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <string>
+#include <vector>
+
+struct test_model {
+    struct ggml_tensor * q;
+    struct ggml_tensor * k;
+    struct ggml_tensor * v;
+    ggml_backend_t backend = NULL;
+    ggml_backend_buffer_t buffer;
+    struct ggml_context * ctx;
+};
+
+static std::vector<float> tensor_to_float(const ggml_tensor * t) {
+    std::vector<float> tv;
+    tv.reserve(ggml_nelements(t));
+
+    std::vector<uint8_t> buf(ggml_nbytes(t));
+    ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));
+
+    ggml_type_traits_t tt = ggml_internal_get_type_traits(t->type);
+    size_t bs = ggml_blck_size(t->type);
+    std::vector<float> vq(ggml_blck_size(t->type));
+    bool quantized = ggml_is_quantized(t->type);
+
+    // access elements by index to avoid gaps in views
+    for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
+        for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
+            for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
+                for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
+                    size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
+                    if (t->type == GGML_TYPE_F16) {
+                        tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
+                    } else if (t->type == GGML_TYPE_F32) {
+                        tv.push_back(*(float *) &buf[i]);
+                    } else if (t->type == GGML_TYPE_I32) {
+                        tv.push_back((float)*(int32_t *) &buf[i]);
+                    } else if (t->type == GGML_TYPE_I16) {
+                        tv.push_back((float)*(int16_t *) &buf[i]);
+                    } else if (t->type == GGML_TYPE_I8) {
+                        tv.push_back((float)*(int8_t *) &buf[i]);
+                    } else if (quantized) {
+                        std::vector<float> vq(ggml_blck_size(t->type));
+                        tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type));
+                        tv.insert(tv.end(), vq.begin(), vq.end());
+                    } else {
+                        GGML_ASSERT(false);
+                    }
+                }
+            }
+        }
+    }
+
+    return tv;
+}
+
+// accept FLT_MAX as infinity
+static bool isinf_or_max(float f) {
+    return std::isinf(f) || f == FLT_MAX || f == -FLT_MAX;
+}
+
+// normalized mean squared error = mse(a, b) / mse(a, 0)
+static double nmse(const float * a, const float * b, size_t n) {
+    double mse_a_b = 0.0;
+    double mse_a_0 = 0.0;
+
+    for (size_t i = 0; i < n; i++) {
+        float a_i = a[i];
+        float b_i = b[i];
+
+        mse_a_b += (a_i - b_i) * (a_i - b_i);
+        mse_a_0 += a_i * a_i;
+    }
+
+    return mse_a_b / mse_a_0;
+}
+
+void ggml_tensor_set_f32(struct ggml_tensor* tensor, float value, int l, int k = 0, int j = 0, int i = 0) {
+    GGML_ASSERT(tensor->nb[0] == sizeof(float));
+    *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]) = value;
+}
+
+float ggml_tensor_get_f32(const ggml_tensor* tensor, int l, int k = 0, int j = 0, int i = 0) {
+    GGML_ASSERT(tensor->nb[0] == sizeof(float));
+    return *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]);
+}
+
+void load_model(test_model & model, bool use_gpu = false) {
+    float Query[30] = { // [3, 4, 2]
+        // z0
+        2, 4, 2,
+        4, 2, 1,
+        4, 1, 3,
+        4, 2, 2,
+
+        // z1
+        2, 1, 1,
+        4, 2, 1,
+        1, 1, 3,
+        4, 2, 1
+    };
+
+    float Key[24] = { // [3, 4, 2]
+        // z0
+        2, 4, 2,
+        4, 2, 1,
+        4, 2, 3,
+        1, 2, 1,
+
+        // z1
+        3, 1, 3,
+        4, 2, 1,
+        1, 1, 2,
+        4, 3, 1
+    };
+
+    float Value[24] = { // [4, 3, 2]
+        // z0
+        2, 4, 2, 1,
+        2, 1, 4, 2,
+        1, 4, 2, 3,
+
+        // z1
+        1, 4, 2, 1,
+        2, 1, 1, 2,
+        1, 4, 3, 3,
+    };
+
+    size_t buffer_size = 0;
+    {
+        buffer_size += 30 * ggml_type_sizef(GGML_TYPE_F32); // tensor q
+        buffer_size += 24 * ggml_type_sizef(GGML_TYPE_F32); // tensor k
+        buffer_size += 24 * ggml_type_sizef(GGML_TYPE_F32); // tensor v
+        buffer_size += 1024;
+    }
+
+    printf("%s: ggml tensor size    = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
+    printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f));
+
+    int num_tensors = 3;
+    struct ggml_init_params params {
+            /*.mem_size   =*/ ggml_tensor_overhead() * num_tensors,
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+    };
+
+    // initialize the backend
+#ifdef GGML_USE_CUBLAS
+    if (use_gpu) {
CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(0); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.q = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 3, 4, 2); + model.k = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 3, 4, 2); + model.v = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 4, 3, 2); + + // create a allocator + ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer); + + // alloc memory + ggml_allocr_alloc(alloc, model.q); + ggml_allocr_alloc(alloc, model.k); + ggml_allocr_alloc(alloc, model.v); + + ggml_backend_tensor_set(model.q, Query, 0, ggml_nbytes(model.q)); + ggml_backend_tensor_set(model.k, Key, 0, ggml_nbytes(model.k)); + ggml_backend_tensor_set(model.v, Value, 0, ggml_nbytes(model.v)); + + ggml_allocr_free(alloc); +} + +struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * allocr) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false); + ggml_build_forward_expand(gf, result); + + // delete the temporally context used to build the graph + ggml_free(ctx0); + return gf; +} + +struct ggml_tensor* compute_graph(const test_model & model, ggml_backend_t backend_cpu, struct ggml_allocr * allocr, bool compare_backends) { + // reset the allocator to free all the memory allocated during the previous inference + ggml_allocr_reset(allocr); + + struct ggml_cgraph * gf = build_graph(model, allocr); + + // allocate tensors + ggml_allocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + + if(!compare_backends) { + ggml_backend_graph_compute(model.backend, gf); + + // in this case, the output tensor is the last one in the graph + return gf->nodes[gf->n_nodes - 1]; + } + + struct callback_userdata { + bool ok; + double max_err; + ggml_backend_t backend1; + ggml_backend_t backend2; + }; + + callback_userdata ud { + true, + 1e-7, + model.backend, + backend_cpu + }; + + auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { + callback_userdata * ud = (callback_userdata *) user_data; + const char * bn1 = ggml_backend_name(ud->backend1); + const char * bn2 = ggml_backend_name(ud->backend2); + + if (t1->op == GGML_OP_NONE) { + // sentinels must be unchanged + std::vector t1_data(ggml_nbytes(t1)); + std::vector t2_data(ggml_nbytes(t2)); + ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1)); + ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2)); + + if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) { + printf("sentinel mismatch: %s ", t1->name); + ud->ok = false; + return true; + } + } + + std::vector f1 = tensor_to_float(t1); + 
+        std::vector<float> f2 = tensor_to_float(t2);
+
+        for (size_t i = 0; i < f1.size(); i++) {
+            // check for nans
+            if (std::isnan(f1[i]) || std::isnan(f2[i])) {
+                printf("[%s] NaN at index %zu (%s=%f %s=%f) ", ggml_op_desc(t1), i, bn1, f1[i], bn2, f2[i]);
+                ud->ok = false;
+                return true;
+            }
+            // check for infs: both must be inf of the same sign, or both must be finite
+            if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
+                if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
+                    if (std::signbit(f1[i]) != std::signbit(f2[i])) {
+                        printf("[%s] inf sign mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);
+                        ud->ok = false;
+                        return true;
+                    }
+                } else {
+                    printf("[%s] inf mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);
+                    ud->ok = false;
+                    return true;
+                }
+            }
+        }
+
+        double err = nmse(f1.data(), f2.data(), f1.size());
+        if (err > ud->max_err) {
+            printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
+            ud->ok = false;
+        }
+
+        return true;
+
+        GGML_UNUSED(index);
+    };
+
+    printf("\nTesting Flash Attention - comparing backends: ");
+
+    const bool cmp_ok = ggml_backend_compare_graph_backend(model.backend, backend_cpu, gf, callback, &ud);
+    if (ud.ok && cmp_ok) {
+        printf("\033[1;32mOK\033[0m\n");
+        return NULL;
+    }
+
+    printf("\033[1;31mFAIL\033[0m\n");
+    return NULL;
+}
+
+int main(int argc, char ** argv)
+{
+    bool compare_backend = false;
+    for (int i = 1; i < argc; i++) {
+        if (strcmp(argv[i], "comp") == 0) {
+            compare_backend = true;
+        }
+    }
+
+    ggml_time_init();
+
+    test_model model;
+    load_model(model, true);
+
+    ggml_backend_buffer_t buf_compute; // for compute
+    struct ggml_allocr * allocr = NULL;
+
+    {
+        allocr = ggml_allocr_new_measure_from_backend(model.backend);
+
+        // create the worst case graph for memory usage estimation
+        struct ggml_cgraph * gf = build_graph(model, allocr);
+        size_t mem_size = ggml_allocr_alloc_graph(allocr, gf);
+        ggml_allocr_free(allocr);
+
+        // compute the required memory
+        buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
+        allocr = ggml_allocr_new_from_buffer(buf_compute);
+        fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
+    }
+
+    ggml_backend_t backend_cpu = ggml_backend_cpu_init();
+
+    struct ggml_tensor * result = compute_graph(model, backend_cpu, allocr, compare_backend);
+    if(!compare_backend) {
+        float* data = new float[ggml_nelements(result)];
+
+        ggml_backend_tensor_get(result, data, 0, ggml_nbytes(result));
+        printf("\nPerforming test:\n");
+
+        for(int i = 0; i < ggml_nelements(result); i++) {
+            if(i > 0 && (i % result->ne[0] == 0)) {
+                printf("\n");
+            }
+            printf("%2.6f ", data[i]);
+        }
+    }
+
+    ggml_free(model.ctx);
+
+    ggml_backend_buffer_free(model.buffer);
+    ggml_backend_buffer_free(buf_compute);
+    ggml_backend_free(model.backend);
+    return 0;
+}
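
Reviewer note (not part of the patch): the kernel above assigns one CUDA block per (head, query position) pair and computes dst = softmax(kq_scale * Q K^T) V, using a shared-memory max/sum reduction for a numerically stable softmax. The sketch below is a minimal CPU reference for the same math under the layout assumptions the launcher makes (Q and K as [d_head, seq_len, n_heads], V as [seq_len, d_head, n_heads], all contiguous f32); it can be used to hand-check the values printed by test-flash-attention.cpp when it is run without the "comp" argument. The function name naive_flash_attn_ref is illustrative only and does not exist in the patch, and the actual ggml CPU implementation of GGML_OP_FLASH_ATTN may differ in details such as masking.

#include <algorithm>
#include <cmath>
#include <vector>

// dst = softmax(kq_scale * Q K^T) V, computed independently for every head h and
// query position s, mirroring the per-block work of flash_attn_f32.
static void naive_flash_attn_ref(const float * q, const float * k, const float * v, float * dst,
                                 int d_head, int seq_len, int n_heads) {
    const float kq_scale  = 1.0f / std::sqrt((float) d_head);
    const int   head_size = d_head * seq_len;
    std::vector<float> S(seq_len);

    for (int h = 0; h < n_heads; h++) {
        for (int s = 0; s < seq_len; s++) {
            // attention scores of query s against every key position is
            float max_val = -INFINITY;
            for (int is = 0; is < seq_len; is++) {
                float acc = 0.0f;
                for (int d = 0; d < d_head; d++) {
                    acc += k[is*d_head + h*head_size + d] * q[s*d_head + h*head_size + d];
                }
                S[is]   = acc * kq_scale;
                max_val = std::max(max_val, S[is]);
            }
            // softmax with max subtraction for numerical stability
            float sum = 0.0f;
            for (int is = 0; is < seq_len; is++) {
                S[is] = std::exp(S[is] - max_val);
                sum  += S[is];
            }
            for (int is = 0; is < seq_len; is++) {
                S[is] /= sum;
            }
            // weighted sum over V: dst[d, s, h] = sum_ic v[ic, d, h] * S[ic]
            for (int d = 0; d < d_head; d++) {
                float acc = 0.0f;
                for (int ic = 0; ic < seq_len; ic++) {
                    acc += v[d*seq_len + h*head_size + ic] * S[ic];
                }
                dst[d + s*d_head + h*head_size] = acc;
            }
        }
    }
}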