diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 330fc6290..e7bf95bd1 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6495,8 +6495,8 @@ static __global__ void flash_attn_ext_f16(
     half M[Q];

     for(int i = 0; i < Q; i++) {
-        S[i] = 0.0f;
-        M[i] = -INFINITY;
+        S[i] = __float2half(0.0f);
+        M[i] = __float2half(-INFINITY);
     }

     // assume K and V are same shape
@@ -6579,7 +6579,7 @@ static __global__ void flash_attn_ext_f16(
         }

         // used to detect blocks full of -INF
-        half smax = -INFINITY;
+        half smax = __float2half(-INFINITY);

         // online softmax
         if (C == 32) {
@@ -6592,8 +6592,8 @@ static __global__ void flash_attn_ext_f16(
                 smax = warp_reduce_max(__hmax(smax, s));
                 M[j] = warp_reduce_max(__hmax(M[j], s));

-                const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);
-                const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);
+                const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
+                const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);

                 S[j] = S[j]*ms + warp_reduce_sum(vs);

@@ -6612,33 +6612,38 @@ static __global__ void flash_attn_ext_f16(
                 for (int64_t p = lane_id; p < C; p += NW) {
                     const half s = ss[j*T + p];

-                    smax = warp_reduce_max(__hmax(smax, s));
-                    M[j] = warp_reduce_max(__hmax(M[j], s));
+                    smax = __hmax(smax, s);
+                    M[j] = __hmax(M[j], s);
                 }

-                const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);
+                smax = warp_reduce_max(smax);
+                M[j] = warp_reduce_max(M[j]);

-                S[j] = S[j]*ms;
+                const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);

                 // create a QxQ diagonal matrix for rescaling the output
                 if (lane_id == j) {
                     ss[j*T + C + j] = ms;
                 }

+                // local sum
+                half ls = 0.0f;
+
                 for (int64_t p = lane_id; p < C; p += NW) {
                     const half s = ss[j*T + p];

-                    const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);
+                    const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);

-                    S[j] = S[j] + warp_reduce_sum(vs);
+                    ls += vs;

                     // the P matrix from the paper (Q rows, C columns)
                     ss[j*T + p] = vs;
                 }
+
+                S[j] = S[j]*ms + warp_reduce_sum(ls);
             }
         }

-
         // skip -INF blocks
         if (__hisinf(smax)) {
             continue;
@@ -6669,15 +6674,19 @@ static __global__ void flash_attn_ext_f16(
             for (int cc = 0; cc < C/16; ++cc) {
                 const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));

+                half16x16_b mk[D16];
                 for (int64_t i = 0; i < D16; ++i) {
-                    half16x16_b mk;
-                    nvcuda::wmma::load_matrix_sync(mk, pv + i*16, nb21/sizeof(half));
+                    nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
+                }

-                    for (int64_t j = 0; j < Q16; ++j) {
-                        half16x16_a mv;
-                        nvcuda::wmma::load_matrix_sync(mv, ss + 16*j*T + 16*cc, T);
+                half16x16_a mv[Q16];
+                for (int64_t j = 0; j < Q16; ++j) {
+                    nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
+                }

-                        nvcuda::wmma::mma_sync(lo[j][i], mv, mk, lo[j][i]);
+                for (int64_t j = 0; j < Q16; ++j) {
+                    for (int64_t i = 0; i < D16; ++i) {
+                        nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
                     }
                 }
             }
@@ -6695,8 +6704,8 @@ static __global__ void flash_attn_ext_f16(

         // reduce the warps sequentially
         for (int64_t sg = 1; sg < num_warps; ++sg) {
-            half S = 0.0f;
-            half M = -INFINITY;
+            half S = __float2half(0.0f);
+            half M = __float2half(-INFINITY);

             __syncthreads();

@@ -6722,8 +6731,8 @@ static __global__ void flash_attn_ext_f16(

                     M = __hmax(M0, M1);

-                    const half ms0 = __hisinf(M0) ? 0.0f : expf(M0 - M);
-                    const half ms1 = __hisinf(M1) ? 0.0f : expf(M1 - M);
+                    const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
+                    const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);

                     S = S0*ms0 + S1*ms1;

@@ -6770,8 +6779,6 @@ static __global__ void flash_attn_ext_f16(
         }
     }

-    // float2 * dst2 = (float2 *) dst;
-
     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
@@ -9637,7 +9644,7 @@ static void ggml_cuda_op_soft_max(

     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
+    const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded!

     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
@@ -10932,7 +10939,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     memcpy(&scale, KQV->op_params, sizeof(float));

 #define NQPB 16
-#define NCPW 32
+#define NCPW 128

     const int nqpb = NQPB; // queries per block
     const int ncpw = NCPW; // cache values per warp (does not work for other values)
diff --git a/ggml.c b/ggml.c
index 59a4c05a1..ebd9c6b34 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5089,7 +5089,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[2] == 1);
         GGML_ASSERT(mask->ne[3] == 1);
-        GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+        GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }

     bool is_node = false;
diff --git a/llama.cpp b/llama.cpp
index fe2583966..2330efff5 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6881,7 +6881,8 @@ static int llama_decode_internal(
         // a heuristic, to avoid attending the full cache if it is not yet utilized
         // after enough generations, the benefit from this heuristic disappears
         // if we start defragmenting the cache, the benefit from this will be more important
-        kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+        // note: we pad the n_kv because certain GPU kernels require it (e.g. ggml_flash_attn_ext)
+        kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(128, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
         //kv_self.n = llama_kv_cache_cell_max(kv_self);

         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index ff207e21b..e23384eee 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2210,7 +2210,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_leaky_relu());

 #if 1
-    for (int hs : { 64, 80, 128, }) {
+    for (int hs : { 128, 64, 80, }) {
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, 2048, 4096, }) {
                 for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {
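
Note on the main kernel change above: in the C != 32 online-softmax path, the per-iteration warp_reduce_max/warp_reduce_sum calls inside the p-loop are replaced by per-lane accumulation (smax, M[j], ls) followed by a single warp reduction after the loop. The standalone CUDA sketch below illustrates only that reduction pattern; it is not the flash_attn_ext_f16 kernel. The softmax_row_stats kernel and the shuffle-based warp_reduce_* helpers here are written for illustration (the patch uses the existing ggml-cuda helpers), and the sketch assumes a full 32-lane warp and a CUDA toolkit/architecture where the half intrinsics the patch already relies on (__hmax, hexp, half __shfl_xor_sync) are available.

#include <cuda_fp16.h>
#include <math.h>

#define WARP_SIZE 32

static __device__ __forceinline__ half warp_reduce_max(half x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x = __hmax(x, __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE));
    }
    return x;
}

static __device__ __forceinline__ half warp_reduce_sum(half x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x = __hadd(x, __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE));
    }
    return x;
}

// one warp computes the softmax statistics of a row of C scores: the row max M
// and the sum S of exp(s - M); each pass accumulates per lane and reduces across
// the warp once at the end, instead of reducing on every loop iteration
__global__ void softmax_row_stats(const half * s, const int C, half * out_M, half * out_S) {
    const int lane_id = threadIdx.x % WARP_SIZE;

    // pass 1: per-lane max, then a single warp reduction
    half M = __float2half(-INFINITY);
    for (int p = lane_id; p < C; p += WARP_SIZE) {
        M = __hmax(M, s[p]);
    }
    M = warp_reduce_max(M);

    // pass 2: per-lane local sum of exp(s - M), again reduced once at the end
    half ls = __float2half(0.0f);
    for (int p = lane_id; p < C; p += WARP_SIZE) {
        ls = __hadd(ls, hexp(__hsub(s[p], M)));
    }
    const half S = warp_reduce_sum(ls);

    if (lane_id == 0) {
        *out_M = M;
        *out_S = S;
    }
}

Launched with a single warp, e.g. softmax_row_stats<<<1, 32>>>(d_scores, C, d_M, d_S), this produces the same two per-row statistics (M and S) that the kernel keeps per query row, with one warp reduction per pass.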