cuda : batched cuBLAS GEMMs for src0 F16 and src1 F32 (attention ops)

This commit is contained in:
Georgi Gerganov 2023-10-23 20:37:04 +03:00
parent 84d4ca0e47
commit c13fcfbfc0
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -7013,7 +7013,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
} }
static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)); GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(!ggml_is_permuted(src0)); GGML_ASSERT(!ggml_is_permuted(src0));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == GGML_TYPE_F16);
@ -7023,11 +7024,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
const int64_t ne01 = src0->ne[1]; const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2]; const int64_t ne02 = src0->ne[2];
const int64_t ne12 = src1->ne[2];
const int64_t nb01 = src0->nb[1]; const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2]; const int64_t nb02 = src0->nb[2];
const int64_t ne12 = src1->ne[2];
CUDA_CHECK(ggml_cuda_set_device(g_main_device)); CUDA_CHECK(ggml_cuda_set_device(g_main_device));
cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
@ -7046,6 +7047,154 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
} }
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
const int64_t ne13 = src1->ne[3];
const int64_t nb11 = src1->nb[1];
const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
const int64_t ne1 = ggml_nelements(src1);
const int64_t ne = ggml_nelements(dst);
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
int id;
CUDA_CHECK(cudaGetDevice(&id));
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
void * src0_ddq = src0_extra->data_device[g_main_device];
half * src0_as_f16 = (half *) src0_ddq;
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
// convert src1 to fp16
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t src1_as = 0;
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
size_t dst_as = 0;
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
// broadcast factors
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
#if 0
// use cublasGemmEx
{
for (int i13 = 0; i13 < ne13; ++i13) {
for (int i12 = 0; i12 < ne12; ++i12) {
int i03 = i13 / r3;
int i02 = i12 / r2;
CUBLAS_CHECK(
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
&beta_f16, (char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}
}
#else
// use cublasGemmBatchedEx
{
const int ne23 = ne12*ne13;
// TODO: avoid this alloc
void ** src0_ptrs = (void **) malloc(ne23*sizeof(void *));
void ** src1_ptrs = (void **) malloc(ne23*sizeof(void *));
void ** dst_ptrs = (void **) malloc(ne23*sizeof(void *));
for (int i13 = 0; i13 < ne13; ++i13) {
for (int i12 = 0; i12 < ne12; ++i12) {
int i03 = i13 / r3;
int i02 = i12 / r2;
src0_ptrs[i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3];
src1_ptrs[i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
dst_ptrs [i12 + i13*ne12] = (char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
}
}
// allocate device memory for pointers
void ** src0_ptrs_as = nullptr;
void ** src1_ptrs_as = nullptr;
void ** dst_ptrs_as = nullptr;
CUDA_CHECK(cudaMalloc(&src0_ptrs_as, ne23*sizeof(void *)));
CUDA_CHECK(cudaMalloc(&src1_ptrs_as, ne23*sizeof(void *)));
CUDA_CHECK(cudaMalloc(& dst_ptrs_as, ne23*sizeof(void *)));
// copy pointers to device
CUDA_CHECK(cudaMemcpy(src0_ptrs_as, src0_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(src1_ptrs_as, src1_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy( dst_ptrs_as, dst_ptrs, ne23*sizeof(void *), cudaMemcpyHostToDevice));
CUBLAS_CHECK(
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha_f16, (void **) src0_ptrs_as, CUDA_R_16F, nb01/sizeof(half),
(void **) src1_ptrs_as, CUDA_R_16F, nb11/sizeof(float),
&beta_f16, (void **) dst_ptrs_as, CUDA_R_16F, ne01,
ne23,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// free device memory for pointers
CUDA_CHECK(cudaFree(src0_ptrs_as));
CUDA_CHECK(cudaFree(src1_ptrs_as));
CUDA_CHECK(cudaFree( dst_ptrs_as));
free(src0_ptrs);
free(src1_ptrs);
free( dst_ptrs);
}
#endif
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
ggml_cuda_pool_free(src1_as_f16, src1_as);
ggml_cuda_pool_free(dst_f16, dst_as);
}
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
@ -7058,10 +7207,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
} }
} }
// debug helpers
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
//printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
//printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
// KQ
ggml_cuda_mul_mat_vec_p021(src0, src1, dst); ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
} else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) { } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
// KQV
ggml_cuda_mul_mat_vec_nc(src0, src1, dst); ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
} else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
} else if (src0->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_F32) {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {