cuda: mask as fp16

FSSRepo 2024-01-31 16:22:11 -05:00
parent 3df0b8d47c
commit fd878f71ed


@@ -6529,7 +6529,7 @@ static __global__ void flash_attn_ext_f16(
     const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;

     // pointer to the mask
-    const float * mp = mask ? (const float *) (mask + (ir%ne31)*nb31) : nullptr;
+    const half  * mp = mask ? (const half  *) (mask + iq1*nb31) : nullptr;

     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
@@ -6555,12 +6555,9 @@ static __global__ void flash_attn_ext_f16(

                     // mqk = mqk*scale + mask
                     for (int64_t j = 0; j < Q16; ++j) {
-                        // const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
-                        // int64_t msk_ne_row = nb31/sizeof(float);
                         for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                            // int msk_col = i % 16;
-                            // int msk_row = i / 16;
-                            mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; // __half2float() + msk_p[msk_col + msk_row*msk_ne_row]);
+                            // TODO: process mask
+                            mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
                         }
                         nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                     }
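
The masking itself is left as a TODO in this hunk. A hypothetical sketch of how the fp16 mask tile could be folded in, assuming (as the removed fp32 comments suggest) that nb31 is the mask row stride in bytes and ic + 16*cc is the current KV column offset, is to load each 16x16 mask tile into an accumulator fragment shaped like mqk[j] and add it element-wise; fragments with identical template parameters share the same element mapping within a warp, so the per-element add lines up with mqk[j]. This is a sketch only, not part of this commit:

    // Hypothetical sketch: mqk = mqk*scale + mask, with the mask in fp16.
    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> mm;
    for (int64_t j = 0; j < Q16; ++j) {
        if (mp) {
            // assumed layout: nb31 bytes per mask row, columns ic + 16*cc .. ic + 16*cc + 15
            nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc,
                                           nb31/sizeof(half), nvcuda::wmma::mem_row_major);
        }
        for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
            // element-wise scale-and-add; identical fragment types keep the element order consistent
            mqk[j].x[i] = __float2half(scale) * mqk[j].x[i] + (mp ? mm.x[i] : __float2half(0.0f));
        }
        nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
    }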
@@ -9216,7 +9213,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
         src1_dfloat = src1_dfloat_a.alloc(ne00);
         ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
                               ne00, 1, sizeof(float), 0, 0,
-                              ne00, 1, sizeof(half), 0, 0, stream);
+                              ne00, 1, sizeof(half), 0, 0, 0, 0, 0, 0, stream);
     }
 #else
     const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
@@ -10891,19 +10888,18 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *

     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(K->type == GGML_TYPE_F16);
     GGML_ASSERT(V->type == GGML_TYPE_F16);
-    if(mask) {
-        GGML_ASSERT(mask->type == GGML_TYPE_F32);
-    }
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);

     GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
-    if(mask) {
-        GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
-    }
     GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);

+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
+            "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
+
     ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
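
For reference, a minimal host-side sketch of what the new assertions expect from the caller (the function name build_kq_mask and the parameters n_kv and n_queries are illustrative, not from this commit): the mask must be an F16 tensor on the GPU backend whose second dimension is at least the number of query rows rounded up to a multiple of 8. ggml uses an additive mask, typically 0 for visible positions and -INFINITY (converted to fp16) for masked ones; offloading the tensor to the GPU backend is left to the usual ggml-cuda machinery.

    // Illustrative sketch only: allocate a KQ mask compatible with the asserts above.
    // ne[0] = KV length, ne[1] = number of query rows padded up to a multiple of 8.
    static struct ggml_tensor * build_kq_mask(struct ggml_context * ctx, int64_t n_kv, int64_t n_queries) {
        return ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_kv, GGML_PAD(n_queries, 8));
    }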
@@ -10925,7 +10921,6 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     dim3 block_dim(32, nwarps, 1);

     int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
-    printf("shared memory: %d bytes [%i, %i, %i] scale = %f\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2], scale);
     switch (Q->ne[0])
     {
         case 16:
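
To put a concrete number on the shared-memory expression above, here is a tiny standalone check under assumed parameters; the values of nqpb, ncpw, nwarps and the head size are illustrative guesses, not taken from this commit:

    // Illustrative only: plug assumed values into the shmem formula above.
    #include <cstdio>

    int main(void) {
        const int nqpb   = 16;   // queries per block      (assumed)
        const int ncpw   = 32;   // cache values per warp  (assumed)
        const int nwarps = 4;    // warps per block        (assumed)
        const int ne0    = 128;  // head size Q->ne[0]     (assumed)

        // same expression as in the hunk above; sizeof(float)/2 == sizeof(half)
        const int shmem = nqpb*(ne0 + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
        std::printf("shmem = %d bytes\n", shmem); // prints 10240
        return 0;
    }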