From fd878f71ed370eb34b85f89e27f07821a9b2c10b Mon Sep 17 00:00:00 2001
From: FSSRepo
Date: Wed, 31 Jan 2024 16:22:11 -0500
Subject: [PATCH] cuda: mask as fp16

---
 ggml-cuda.cu | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 35e2af0f4..86afb0133 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -6529,7 +6529,7 @@ static __global__ void flash_attn_ext_f16(
     const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
 
     // pointer to the mask
-    const float * mp = mask ? (const float *) (mask + (ir%ne31)*nb31) : nullptr;
+    const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
 
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
@@ -6555,12 +6555,9 @@ static __global__ void flash_attn_ext_f16(
 
                 // mqk = mqk*scale + mask
                 for (int64_t j = 0; j < Q16; ++j) {
-                    // const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
-                    // int64_t msk_ne_row = nb31/sizeof(float);
                     for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                        // int msk_col = i % 16;
-                        // int msk_row = i / 16;
-                        mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; // __half2float() + msk_p[msk_col + msk_row*msk_ne_row]);
+                        // TODO: process mask
+                        mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
                     }
                     nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                 }
@@ -9216,7 +9213,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
         src1_dfloat = src1_dfloat_a.alloc(ne00);
         ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
                                 ne00, 1, sizeof(float), 0, 0,
-                                ne00, 1, sizeof(half), 0, 0, stream);
+                                ne00, 1, sizeof(half), 0, 0, 0, 0, 0, 0, stream);
     }
 #else
     const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
@@ -10891,19 +10888,18 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(K->type == GGML_TYPE_F16);
     GGML_ASSERT(V->type == GGML_TYPE_F16);
-    if(mask) {
-        GGML_ASSERT(mask->type == GGML_TYPE_F32);
-    }
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
     GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
-    if(mask) {
-        GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
-    }
     GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
 
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
+                "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
+
     ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
@@ -10925,7 +10921,6 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     dim3 block_dim(32, nwarps, 1);
 
     int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
-    printf("shared memory: %d bytes [%i, %i, %i] scale = %f\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2], scale);
     switch (Q->ne[0]) {
         case 16: