diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 5f6438048..5229e15d2 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -125,6 +125,11 @@
 #include "ggml.h"
 #include "ggml-backend-impl.h"
 
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
 #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 
 #define CC_PASCAL 600
@@ -679,7 +684,6 @@ static __device__ __forceinline__ half warp_reduce_max(half x) {
     return x;
 #else
     (void) x;
-    bad_arch();
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 }
 
@@ -6156,16 +6160,17 @@ static __global__ void flash_attn_f32(
 #if __CUDA_ARCH__ >= CC_VOLTA
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
 typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
 
 // based on metal version
-template<int D, int Q, int C> // D head size, Q queries per block, C cache items per blocks
+template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
 static __global__ void flash_attn_ext_f16(
         const char* __restrict__ q,
         const char* __restrict__ k,
         const char* __restrict__ v,
         const char* __restrict__ mask,
-        float* __restrict__ kqv,
+        float* __restrict__ dst,
         float scale,
         int ne00,
         int ne01,
@@ -6190,57 +6195,64 @@ static __global__ void flash_attn_ext_f16(
     const int warp_id = threadIdx.y;
     const int lane_id = threadIdx.x;
 
-    const int n_warps = blockDim.y; // number of warps
+    const int num_warps = blockDim.y; // number of warps
 
    const int iq3 = blockIdx.z;
    const int iq2 = blockIdx.y;
    const int iq1 = blockIdx.x * Q;
 
    const int D2 = D/2;
-    const int N4 = WARP_SIZE;
-    const int L2 = (D2 + N4 - 1)/N4;
    const int D16 = D/16;
+    const int Q16 = Q/16;
+    const int NW = WARP_SIZE;
+    const int SH = (C + D); // shared memory per simdgroup in (half)
 
-    const int T = D + n_warps*(D + 1*C); // shared memory size per query in half
-    const int T2 = T/2;                  // shared memory size per query in half2
-
-    const half scale_h = __float2half(scale);
+    const int T = D + num_warps*SH; // shared memory size per query in (half)
+    const int T2 = T/2;             // shared memory size per query in (half2)
 
    extern __shared__ half __flash_attn_f16_shmem[];
 
    // pq
-    half  * pq  = (half  *) (__flash_attn_f16_shmem + 0*D);
-    half2 * pq2 = (half2 *) (__flash_attn_f16_shmem + 0*D);
-    half  * ps  = (half  *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 1*D);
-    half2 * ps2 = (half2 *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 1*D);
-    half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*(D + 1*C) + 2*D);
+    half  * sq  = (half  *) (__flash_attn_f16_shmem + 0*D);              // holds the query data
+    half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D);              // same as above but in half2
+    half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
+
+    half16x16_acc lo[Q16][D16];
 
-    for (int i = 0; i < L2; ++i) {
-        // load heads from Q to shared memory
-        for (int j = warp_id; j < Q; j += n_warps) {
-            const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+    // load heads from Q to shared memory
+    for (int64_t j = warp_id; j < Q; j += num_warps) {
+        const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+        for (int64_t i = lane_id; i < D2; i += NW) {
            if (iq1 + j < ne01) {
-                pq2[j*T2 + N4*i + lane_id] = __float22half2_rn(q2[N4*i + lane_id]);
+                sq2[j*T2 + i] = __float22half2_rn(q2[i]);
            } else {
-                pq2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0);
+                sq2[j*T2 + i] = make_half2(0.0, 0.0);
            }
        }
-
-        // zero out shared memory
-        for (int j = 0; j < Q; ++j) {
-            ps2[j*T2 + N4*i + lane_id] = make_half2(0.0, 0.0);
-        }
    }
 
-    if (lane_id < C) {
-        for (int j = 0; j < Q; ++j) {
-            ss[j*T + 0 + lane_id] = 0.0;
+    // zero out lo
+    for (int64_t j = 0; j < Q16; ++j) {
+        for (int64_t i = 0; i < D16; ++i) {
+            nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+        }
+    }
+
+    // zero out shared memory SH
+    for (int64_t j = 0; j < Q; ++j) {
+        for (int64_t i = lane_id; i < SH; i += NW) {
+            ss[j*T + i] = 0.0;
        }
    }
 
    __syncthreads();
 
-#if 0
+    {
-        half S[Q] = { 0.0 };       // could be half2 S[Q/2] = how fill this array with zeros??
-        half M[Q] = { -INFINITY }; // could be half2 M[Q/2] = better register utilization
+        float S[Q];
+        float M[Q];
+
+        for(int i = 0; i < Q;i ++) {
+            S[i] = 0.0f;
+            M[i] = -INFINITY;
+        }
 
        // assume K and V are same shape
        const int ne22 = ne12;
@@ -6265,162 +6277,252 @@ static __global__ void flash_attn_ext_f16(
        const int iv2 = iq2 / rv2;
        const int iv3 = iq3 / rv3;
 
-        // TODO: this can be improved
-        float * mp[Q];
-
-        {
-            const int ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
-
-            for (int j = 0; j < Q; ++j) {
-                if (iq1 + j < ne01) {
-                    mp[j] = (float *)(mask + ((ir + j)%ne31) * nb31);
-                } else {
-                    mp[j] = nullptr;
-                }
+        // load the queries from shared memory into local memory
+        half16x16_a mq[Q16][D16];
+        for (int64_t j = 0; j < Q16; ++j) {
+            for (int64_t i = 0; i < D16; ++i) {
+                nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
            }
        }
 
-        for (int iic = C*warp_id; iic < ne11; iic += C*n_warps) {
-            // skip -INF blocks
-            // TODO: double-check this
-            {
-                float smc = -INFINITY;
+        const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
 
-                for (int j = 0; j < Q; ++j) {
-                    const float mc = mp[j] ? mp[j][iic + lane_id] : -INFINITY;
-                    smc = warp_reduce_max(max(smc, mc));
-                }
-
-                if (smc == -INFINITY) {
-                    continue;
-                }
-            }
+        // pointer to the mask
+        const float * mp = (const float *) (mask + (ir%ne31)*nb31);
 
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int64_t ic = C*warp_id; ic < ne11; ic += C*num_warps) {
            // Q*K^T
            {
-                half16x16_a   mq;
-                half16x16_b   mk;
-                half16x16_acc mqk;
-
                for (int cc = 0; cc < C/16; ++cc) {
-                    nvcuda::wmma::fill_fragment(mqk, 0);
-
-                    const half * pk = (const half *) (k + ((iic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
-
-                    for(int i = 0; i < D16;i ++) {
-                        nvcuda::wmma::load_matrix_sync(mq, pq + i*16, T);
-                        nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
-                        nvcuda::wmma::mma_sync(mqk, mq, mk, mqk);
+                    half16x16_acc mqk[Q16];
+                    for (int64_t j = 0; j < Q16; ++j) {
+                        nvcuda::wmma::fill_fragment(mqk[j], 0);
                    }
 
-                    nvcuda::wmma::store_matrix_sync(ss + 16*cc, mqk, T, nvcuda::wmma::mem_col_major);
+                    const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (int64_t i = 0; i < D16; ++i) {
+                        half16x16_bT mk; // transposed key
+                        nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); // transpose
+
+                        for (int64_t j = 0; j < Q16; ++j) {
+                            nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
+                        }
+                    }
+
+                    // mqk = mqk*scale + mask
+                    for (int64_t j = 0; j < Q16; ++j) {
+                        const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
+                        int64_t msk_ne_row = nb31/sizeof(float);
+                        for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
+                            int msk_col = i % 16;
+                            int msk_row = i / 16;
+                            mqk[j].x[i] = __float2half(scale * __half2float(mqk[j].x[i]) + msk_p[msk_col + msk_row*msk_ne_row]);
+                        }
+                        nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_col_major);
+                    }
                }
            }
 
+            // used to detect blocks full of -INF
+            float smax = -INFINITY;
+
            // online softmax
-            for (int j = 0; j < Q; ++j) {
-                const int p = lane_id;
+            if (C == 32) {
+                for (int64_t j = 0; j < Q; ++j) {
+                    const int64_t p = lane_id;
 
-                const half s = ss[j*T + p]*scale_h + __float2half(mp[j][iic + p]);
+                    const float m = M[j];
+                    const float s = __half2float(ss[j*T + p]);
 
-                half m = M[j];
+                    smax = warp_reduce_max(max(smax, s));
+                    M[j] = warp_reduce_max(max(M[j], s));
 
-                M[j] = warp_reduce_max(__hmax(M[j], s));
+                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
+                    const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
 
-                const half ms = __hisinf(m) == -1 ? 0.0 : hexp(m - M[j]);
-                const half vs = __hisinf(s) == -1 ? 0.0 : hexp(s - M[j]);
+                    S[j] = S[j]*ms + warp_reduce_sum(vs);
 
-                S[j] = S[j]*ms + warp_reduce_sum(vs);
-
-                for (int i = 0; i < L2; ++i) {
-                    ps2[j*T2 + N4*i + lane_id] *= __half2half2(ms);
-                }
-
-                ss[j*T + p] = vs;
-            }
-
-            __syncthreads();
-
-            // (Q*K^T)*V
-            {
-                half16x16_acc mqkv;
-                half16x16_a   mqk;
-                half16x16_b   mv;
-
-                for (int64_t i = 0; i < D16; ++i) {
-                    nvcuda::wmma::fill_fragment(mqkv, 0);
-
-                    for (int cc = 0; cc < C/16; ++cc) {
-                        const half * pv = (const half *) ((const char *) v + ((iic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
-
-                        nvcuda::wmma::load_matrix_sync(mqk, ss + cc*16, T);
-                        nvcuda::wmma::load_matrix_sync(mv, pv + i*16, nb21/sizeof(half));
-
-                        nvcuda::wmma::mma_sync(mqkv, mqk, mv, mqkv);
+                    // create a QxQ diagonal matrix for rescaling the output
+                    if (p == j) {
+                        ss[j*T + C + j] = __float2half(ms);
                    }
 
-                    nvcuda::wmma::store_matrix_sync(ps + i*16, mqkv, T, nvcuda::wmma::mem_col_major);
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss[j*T + p] = __float2half(vs);
+                }
+            } else {
+                for (int64_t j = 0; j < Q; ++j) {
+                    const float m = M[j];
+
+                    for (int64_t p = lane_id; p < C; p += NW) {
+                        const float s = __half2float(ss[j*T + p]);
+
+                        smax = warp_reduce_max(max(smax, s));
+                        M[j] = warp_reduce_max(max(M[j], s));
+                    }
+
+                    const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
+
+                    S[j] = S[j]*ms;
+
+                    // create a QxQ diagonal matrix for rescaling the output
+                    if (lane_id == j) {
+                        ss[j*T + C + j] = ms;
+                    }
+
+                    for (int64_t p = lane_id; p < C; p += NW) {
+                        const float s = ss[j*T + p];
+
+                        const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+
+                        S[j] = S[j] + warp_reduce_sum(vs);
+
+                        // the P matrix from the paper (Q rows, C columns)
+                        ss[j*T + p] = __float2half(vs);
+                    }
+                }
+            }
+
+            // skip -INF blocks
+            if (smax == -INFINITY) {
+                continue;
+            }
+
+            // O = diag(ms)*O
+            for (int64_t j = 0; j < Q16; ++j) {
+                half16x16_a mm;
+                half16x16_b zro;
+
+                nvcuda::wmma::fill_fragment(zro, 0.0);
+                nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+
+                for (int64_t i = 0; i < D16; ++i) {
+                    nvcuda::wmma::mma_sync(lo[j][i], mm, zro, lo[j][i]);
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                for (int cc = 0; cc < C/16; ++cc) {
+                    const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    for (int64_t i = 0; i < D16; ++i) {
+                        half16x16_b mk;
+                        nvcuda::wmma::load_matrix_sync(mk, pv + i*16, nb21/sizeof(half));
+
+                        for (int64_t j = 0; j < Q16; ++j) {
+                            half16x16_a mv;
+                            nvcuda::wmma::load_matrix_sync(mv, ss + 16*j*T + 16*cc, T);
+
+                            nvcuda::wmma::mma_sync(lo[j][i], mv, mk, lo[j][i]);
+                        }
+                    }
                }
            }
        }
 
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
        for (int64_t j = 0; j < Q; ++j) {
            if (lane_id == 0) {
-                ss[j*T + 0] = S[j];
-                ss[j*T + 1] = M[j];
+                ss[j*T + 0] = __float2half(S[j]);
+                ss[j*T + 1] = __float2half(M[j]);
            }
        }
    }
 
-    __syncthreads();
+    // reduce the warps sequentially
+    for (int64_t sg = 1; sg < num_warps; ++sg) {
+        float S = 0.0f;
+        float M = -INFINITY;
 
-    // reduce the warps
-    // TODO: try parallel reduce
-    if (warp_id == 0) {
-        half S = 0.0;
-        half M = __float2half(-INFINITY);
+        __syncthreads();
 
-        for (int64_t sg = 1; sg < n_warps; ++sg) {
+        // each simdgroup stores its output to shared memory, reusing sq
+        if (warp_id == sg) {
+            for (int64_t j = 0; j < Q16; ++j) {
+                for (int64_t i = 0; i < D16; ++i) {
+                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
+                }
+            }
+        }
+
+        __syncthreads();
+
+        // the first simdgroup accumulates the results from the other simdgroups
+        if (warp_id == 0) {
            for (int64_t j = 0; j < Q; ++j) {
-                const half S0 = ss[j*T + 0];
-                const half S1 = ss[j*T + sg*(D + 1*C) + 0];
+                const float S0 = __half2float(ss[j*T + 0]);
+                const float S1 = __half2float(ss[j*T + sg*SH + 0]);
 
-                const half M0 = ss[j*T + 1];
-                const half M1 = ss[j*T + sg*(D + 1*C) + 1];
+                const float M0 = __half2float(ss[j*T + 1]);
+                const float M1 = __half2float(ss[j*T + sg*SH + 1]);
 
-                M = __hmax(M0, M1);
+                M = max(M0, M1);
 
-                const half ms0 = hexp(M0 - M);
-                const half ms1 = hexp(M1 - M);
+                const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M);
+                const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M);
 
                S = S0*ms0 + S1*ms1;
 
                if (lane_id == 0) {
-                    ss[j*T + 0] = S;
-                    ss[j*T + 1] = M;
-                }
+                    ss[j*T + 0] = __float2half(S);
+                    ss[j*T + 1] = __float2half(M);
 
-                for (int64_t i = 0; i < L2; ++i) {
-                    ps2[j*T2 + N4*i + lane_id] = ps2[j*T2 + N4*i + lane_id]*__half2half2(ms0) + ps2[j*T2 + sg*(D + 1*C)/4 + N4*i + lane_id]*__half2half2(ms1);
+                    ss[j*T + C + j ] = __float2half(ms0);
+                    ss[j*T + C + j + sg*SH] = __float2half(ms1);
+                }
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (int64_t j = 0; j < Q16; ++j) {
+                half16x16_a ms0;
+                half16x16_a ms1;
+                half16x16_b t;
+                half16x16_acc t2;
+
+                nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
+                nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
+
+                for (int64_t i = 0; i < D16; ++i) {
+                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
+                    nvcuda::wmma::mma_sync(t2, ms1, t, t2);
+
+                    // t <- lo
+                    for (uint32_t k = 0; k < t.num_elements; k++) {
+                        t.x[k] = lo[j][i].x[k];
+                    }
+
+                    nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                }
            }
        }
    }
 
-    __syncthreads();
-
-    float2 * dst2 = (float2 *) kqv;
-
+    // store result to shared memory (reuse sq)
    if (warp_id == 0) {
-        for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
-            half2 S = __half2half2(ss[j*T + 0]);
-
-            for (int i = 0; i < L2; ++i) {
-                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + N4*i + lane_id] = __half22float2(ps2[j*T2 + N4*i + lane_id]/S);
+        for (int64_t j = 0; j < Q16; ++j) {
+            for (int64_t i = 0; i < D16; ++i) {
+                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_col_major);
+            }
+        }
+    }
+
+    float2 * dst2 = (float2 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (warp_id == 0) {
+        for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
+            const float S = __half2float(ss[j*T + 0]);
+
+            for (int64_t i = lane_id; i < D2; i += NW) {
+                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i] = __half22float2(sq2[j*T2 + i]);
+                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].x /= S;
+                dst2[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D2 + i].y /= S;
            }
        }
    }
-#endif
 }
 
 #else
 template<int D, int Q, int C> // D head size, Q queries per block, C cache items per blocks
@@ -6451,7 +6553,6 @@ static __global__ void flash_attn_ext_f16(
        int ne1,
        int ne2,
        int ne3) {
-    bad_arch();
 }
 #endif
 
@@ -10446,9 +10547,9 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
    float scale;
    memcpy(&scale, KQV->op_params, sizeof(float));
 
-    const int nwarps = Q->ne[1] < 4 ? 12 : 4;
    const int nqpb = 16; // queries per block
    const int ncpw = 32; // cache values per warp (does not work for other values)
+    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, 32)) : 4;
 
    dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
    dim3 block_dim(32, nwarps, 1);
 
@@ -10457,6 +10558,23 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
    printf("shared memory: %d bytes [%i, %i, %i]\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2]);
 
    switch (Q->ne[0]) {
+        case 16:
+            flash_attn_ext_f16<16, 16, 32>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                    (const char *) src0_extra->data_device[g_main_device], // Query
+                    (const char *) src1_extra->data_device[g_main_device], // Key
+                    (const char *) src2_extra->data_device[g_main_device], // Value
+                    (const char *) src3_extra->data_device[g_main_device], // Mask
+                    (float *) dst_extra->data_device[g_main_device], // dst
+                    scale,
+                    Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                    K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                    mask->ne[1], mask->nb[1],
+                    Q->nb[1], Q->nb[2], Q->nb[3],
+                    K->nb[1], K->nb[2], K->nb[3],
+                    KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                );
+            break;
        case 64:
            flash_attn_ext_f16<64, 16, 32>
                <<<blocks_num, block_dim, shmem, main_stream>>> (
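Note on the kernel changes above: per KV block, flash_attn_ext_f16 keeps a running maximum M[j] and running denominator S[j] for each query row, rescales everything accumulated so far by ms, and only divides by S once at the end. The sketch below is an illustrative scalar version of that recurrence, not part of the patch; the function and variable names are made up, and a single value channel stands in for the whole V row.

// Illustrative only: scalar form of the online-softmax update used by the kernel.
#include <math.h>

void online_softmax_step(float s_new, float v_new, float & M, float & S, float & O) {
    const float M_old = M;
    M = fmaxf(M, s_new);                                            // running row maximum

    const float ms = M_old == -INFINITY ? 0.0f : expf(M_old - M);   // rescale factor for previous accumulation
    const float vs = s_new == -INFINITY ? 0.0f : expf(s_new - M);   // weight of the new (scaled + masked) score

    S = S*ms + vs;        // running softmax denominator
    O = O*ms + vs*v_new;  // running unnormalized output; the final result is O/S
}

Starting from M = -INFINITY and S = O = 0 and applying this step for every cache position reproduces softmax(scores)*V without ever materializing a full score row, which is what the diag(ms) rescale and the final 1/S division in the kernel implement in matrix form.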
diff --git a/tests/test-flash-attention.cpp b/tests/test-flash-attention.cpp
index 74167ed86..5d83eeabd 100644
--- a/tests/test-flash-attention.cpp
+++ b/tests/test-flash-attention.cpp
@@ -2,8 +2,6 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 
-#define GGML_USE_CUBLAS
-
 #ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
 #endif
@@ -22,6 +20,7 @@ struct test_model {
    struct ggml_tensor * q;
    struct ggml_tensor * k;
    struct ggml_tensor * v;
+    struct ggml_tensor * msk;
    ggml_backend_t backend = NULL;
    ggml_backend_buffer_t buffer = NULL;
    struct ggml_context * ctx = NULL;
@@ -102,59 +101,38 @@ float ggml_tensor_get_f32(const ggml_tensor* tensor, int l, int k = 0, int j = 0
    return *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]);
 }
 
-void load_model(test_model & model, bool use_gpu = false) {
-    float Query[30] = { // [3, 4, 2]
-        // z0
-        2, 4, 2,
-        4, 2, 1,
-        4, 1, 3,
-        4, 2, 2,
+void load_model(test_model & model, int head_dim, int batch_size, int kv_size, int num_heads) {
+    float* query = new float[head_dim * batch_size * num_heads];
+    float* key = new float[head_dim * kv_size * num_heads];
+    float* value = new float[head_dim * kv_size * num_heads];
+    float* mask = new float[kv_size * batch_size];
 
-        // z1
-        2, 1, 1,
-        4, 2, 1,
-        1, 1, 3,
-        4, 2, 1
-    };
+    for(int i = 0; i < head_dim*batch_size*num_heads;i ++) {
+        query[i] = i % 3 ? 2.0f : 1.5f;
+    }
 
-    float Key[24] = { // [3, 4, 2]
-        // z0
-        2, 4, 2,
-        4, 2, 1,
-        4, 2, 3,
-        1, 2, 1,
+    for(int i = 0; i < head_dim*kv_size*num_heads;i ++) {
+        key[i] = i % 3 ? 2.3f : 2.8f;
+        value[i] = i % 3 ? 3.5f : 1.5f;
+    }
 
-        // z1
-        3, 1, 3,
-        4, 2, 1,
-        1, 1, 2,
-        4, 3, 1
-    };
-
-    float Value[24] = { // [4, 3, 2]
-        // z0
-        2, 4, 2, 1,
-        2, 1, 4, 2,
-        1, 4, 2, 3,
-
-        // z1
-        1, 4, 2, 1,
-        2, 1, 1, 2,
-        1, 4, 3, 3,
-    };
+    for(int i = 0; i < batch_size*kv_size;i ++) {
+        mask[i] = i % 3 ? 1.0f : 0.0f;
+    }
 
    size_t buffer_size = 0;
    {
-        buffer_size += 30 * ggml_type_sizef(GGML_TYPE_F32); // tensor q
-        buffer_size += 24 * ggml_type_sizef(GGML_TYPE_F32); // tensor k
-        buffer_size += 24 * ggml_type_sizef(GGML_TYPE_F32); // tensor v
+        buffer_size += head_dim * batch_size * num_heads * ggml_type_sizef(GGML_TYPE_F32); // tensor q
+        buffer_size += head_dim * kv_size * num_heads * ggml_type_sizef(GGML_TYPE_F16); // tensor k
+        buffer_size += head_dim * kv_size * num_heads * ggml_type_sizef(GGML_TYPE_F16); // tensor v
+        buffer_size += batch_size * kv_size * ggml_type_sizef(GGML_TYPE_F32); // tensor mask
        buffer_size += 1024;
    }
 
    printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
    printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f));
 
-    int num_tensors = 3;
+    int num_tensors = 4;
    struct ggml_init_params params {
            /*.mem_size =*/ ggml_tensor_overhead() * num_tensors,
            /*.mem_buffer =*/ NULL,
@@ -163,12 +141,10 @@ void load_model(test_model & model, bool use_gpu = false) {
 
    // initialize the backend
 #ifdef GGML_USE_CUBLAS
-    if (use_gpu) {
-        fprintf(stderr, "%s: using CUDA backend\n", __func__);
-        model.backend = ggml_backend_cuda_init(0);
-        if (!model.backend) {
-            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
-        }
+    fprintf(stderr, "%s: using CUDA backend\n", __func__);
+    model.backend = ggml_backend_cuda_init(0);
+    if (!model.backend) {
+        fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
    }
 #endif
 
@@ -183,9 +159,10 @@ void load_model(test_model & model, bool use_gpu = false) {
    model.ctx = ggml_init(params);
 
    // create tensors
-    model.q = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 3, 4, 2);
-    model.k = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 3, 4, 2);
-    model.v = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, 4, 3, 2);
+    model.q = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, head_dim, batch_size, num_heads);
+    model.k = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16, head_dim, kv_size, num_heads);
+    model.v = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F16, head_dim, kv_size, num_heads);
+    model.msk = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, kv_size, batch_size);
 
    // create a allocator
    ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer);
@@ -194,12 +171,18 @@ void load_model(test_model & model, bool use_gpu = false) {
    ggml_allocr_alloc(alloc, model.q);
    ggml_allocr_alloc(alloc, model.k);
    ggml_allocr_alloc(alloc, model.v);
+    ggml_allocr_alloc(alloc, model.msk);
 
-    ggml_backend_tensor_set(model.q, Query, 0, ggml_nbytes(model.q));
-    ggml_backend_tensor_set(model.k, Key, 0, ggml_nbytes(model.k));
-    ggml_backend_tensor_set(model.v, Value, 0, ggml_nbytes(model.v));
+    ggml_fp16_t* k_f16 = new ggml_fp16_t[head_dim * kv_size * num_heads];
+    ggml_fp16_t* v_f16 = new ggml_fp16_t[head_dim * kv_size * num_heads];
 
-    ggml_allocr_free(alloc);
+    ggml_fp32_to_fp16_row(key, k_f16, head_dim * kv_size * num_heads);
+    ggml_fp32_to_fp16_row(value, v_f16, head_dim * kv_size * num_heads);
+
+    ggml_backend_tensor_set(model.q, query, 0, ggml_nbytes(model.q));
+    ggml_backend_tensor_set(model.k, k_f16, 0, ggml_nbytes(model.k));
+    ggml_backend_tensor_set(model.v, v_f16, 0, ggml_nbytes(model.v));
+    ggml_backend_tensor_set(model.msk, mask, 0, ggml_nbytes(model.msk));
 }
 
 struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * allocr) {
@@ -218,7 +201,7 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
    if(!model.naive_attn) {
-        struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false);
+        struct ggml_tensor* result = ggml_flash_attn_ext(ctx0, model.q, model.k, model.v, model.msk, 1.0f / sqrtf(model.q->ne[0]));
        ggml_build_forward_expand(gf, result);
    } else {
        struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
@@ -350,8 +333,7 @@ int main(int argc, char ** argv)
 
    ggml_time_init();
 
-
-    load_model(model, true);
+    load_model(model, 16, 32, 128, 2);
 
    ggml_backend_buffer_t buf_compute; // for compute
    struct ggml_allocr * allocr = NULL;
@@ -385,7 +367,10 @@ int main(int argc, char ** argv)
            if(i > 0 && (i % result->ne[0] == 0)) {
                printf("\n");
            }
-            printf("%2.6f ", data[i]);
+            if(i > 0 && (i % (result->ne[0] * result->ne[2]) == 0)) {
+                printf("\n\n");
+            }
+            printf("%2.4f ", data[i]);
        }
    }
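For checking the values printed by the test, the sketch below is a rough CPU reference of what ggml_flash_attn_ext(q, k, v, msk, scale) is expected to compute for one head, i.e. dst = softmax(scale*(Q*K^T) + mask) * V. It is illustrative only and not part of the patch; the function name and the row-major buffer layout are assumptions chosen to match the head_dim/batch_size/kv_size arguments passed to load_model() above.

// Illustrative only: single-head CPU reference for the flash-attention result.
#include <cmath>
#include <vector>

std::vector<float> attn_reference(const std::vector<float> & Q,    // [n_q x head_dim]
                                  const std::vector<float> & K,    // [kv_size x head_dim]
                                  const std::vector<float> & V,    // [kv_size x head_dim]
                                  const std::vector<float> & mask, // [n_q x kv_size]
                                  int n_q, int kv_size, int head_dim, float scale) {
    std::vector<float> out(n_q*head_dim, 0.0f);
    std::vector<float> p(kv_size);

    for (int iq = 0; iq < n_q; ++iq) {
        // scaled, masked scores for this query row
        float M = -INFINITY;
        for (int ik = 0; ik < kv_size; ++ik) {
            float s = 0.0f;
            for (int d = 0; d < head_dim; ++d) {
                s += Q[iq*head_dim + d]*K[ik*head_dim + d];
            }
            p[ik] = scale*s + mask[iq*kv_size + ik];
            M = std::fmax(M, p[ik]);
        }

        // softmax over the KV dimension
        float S = 0.0f;
        for (int ik = 0; ik < kv_size; ++ik) {
            p[ik] = std::exp(p[ik] - M);
            S += p[ik];
        }

        // weighted sum of V rows
        for (int ik = 0; ik < kv_size; ++ik) {
            for (int d = 0; d < head_dim; ++d) {
                out[iq*head_dim + d] += (p[ik]/S)*V[ik*head_dim + d];
            }
        }
    }

    return out;
}

With the configuration used in main() (head_dim = 16, batch_size = 32, kv_size = 128, two heads), running this per head over the same synthetic query/key/value/mask buffers should match the %2.4f values printed by the test up to fp16 rounding.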