cuda : use CUBLAS_COMPUTE_32F to speed-up and avoid dst cpy

Author: Georgi Gerganov
Date:   2023-10-27 18:13:54 +03:00
Parent: c8d6a1f34a
Commit: 3b9ea655d4


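Note on the change: the FP16 mat-mul paths now call cublasGemmEx / cublasGemmStridedBatchedEx / cublasGemmBatchedEx with CUBLAS_COMPUTE_32F and FP32 alpha/beta, writing the FP32 result directly into the destination buffer (dst_dd_i / dst_ddf). The temporary FP16 dst buffer and the trailing to_fp32_cuda conversion pass ("dst cpy") are therefore removed. A minimal standalone sketch of the new call pattern follows; it is not taken from the commit, assumes cuBLAS 11+, and the buffer names d_a/d_b/d_c are made up for illustration.

// Standalone sketch (assumption: cuBLAS 11+; names d_a/d_b/d_c are illustrative).
// FP16 inputs, FP32 accumulation, FP32 output written directly by cuBLAS,
// so no intermediate FP16 dst buffer and no FP16 -> FP32 copy are needed.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include <cstdio>

int main() {
    const int m = 64, n = 64, k = 64;

    half  * d_a; // FP16 input matrix A (m x k, column-major)
    half  * d_b; // FP16 input matrix B (k x n, column-major)
    float * d_c; // FP32 output matrix C (m x n), consumed as-is afterwards
    cudaMalloc((void **) &d_a, m*k*sizeof(half));
    cudaMalloc((void **) &d_b, k*n*sizeof(half));
    cudaMalloc((void **) &d_c, m*n*sizeof(float));
    cudaMemset(d_a, 0, m*k*sizeof(half));
    cudaMemset(d_b, 0, k*n*sizeof(half));

    cublasHandle_t handle;
    cublasCreate(&handle);

    // With CUBLAS_COMPUTE_32F the alpha/beta scalars are FP32, which is why the
    // commit replaces the half alpha_f16/beta_f16 constants with float ones.
    const float alpha = 1.0f;
    const float beta  = 0.0f;

    cublasStatus_t st = cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            m, n, k,
            &alpha, d_a, CUDA_R_16F, m,
                    d_b, CUDA_R_16F, k,
            &beta,  d_c, CUDA_R_32F, m,
            CUBLAS_COMPUTE_32F,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    printf("cublasGemmEx status: %d\n", (int) st);

    cublasDestroy(handle);
    cudaFree(d_a); cudaFree(d_b); cudaFree(d_c);
    return 0;
}

Compared with the previous CUBLAS_COMPUTE_16F path that produced an FP16 result, this trades the FP16 output for FP32 accumulation and output, which removes the extra to_fp32_cuda pass and the pooled dst_f16 allocation at the cost of writing a wider result.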
@@ -6385,27 +6385,19 @@ inline void ggml_cuda_op_mul_mat_cublas(
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
 
-        size_t dst_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
-
-        const half alpha_f16 = 1.0f;
-        const half beta_f16 = 0.0f;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
 
         CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
         CUBLAS_CHECK(
             cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 row_diff, src1_ncols, ne10,
-                &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
+                &alpha,     src0_ptr, CUDA_R_16F, ne00,
                             src1_ptr, CUDA_R_16F, ne10,
-                &beta_f16,  dst_f16,  CUDA_R_16F, ldc,
-                CUBLAS_COMPUTE_16F,
+                &beta,      dst_dd_i, CUDA_R_32F, ldc,
+                CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
-
-        ggml_cuda_pool_free(dst_f16, dst_as);
-
     if (src0_as != 0) {
         ggml_cuda_pool_free(src0_as_f16, src0_as);
     }
@@ -7189,9 +7181,6 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
-    size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -7199,8 +7188,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    const half alpha_f16 = 1.0f;
-    const half beta_f16 = 0.0f;
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
 
 #if 0
     // use cublasGemmEx
@@ -7213,10 +7202,10 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
             CUBLAS_CHECK(
                     cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                         ne01, ne11, ne10,
-                        &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
+                        &alpha,     (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
                                     (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
-                        &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
-                        CUBLAS_COMPUTE_16F,
+                        &beta,      (      char *)     dst_ddf + i12* dst->nb[2]   + i13* dst->nb[3]  , CUDA_R_32F, ne01,
+                        CUBLAS_COMPUTE_32F,
                         CUBLAS_GEMM_DEFAULT_TENSOR_OP));
         }
     }
@@ -7228,11 +7217,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                &alpha,     (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
                             (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
-                &beta_f16,  (      char *)     dst_f16, CUDA_R_16F, ne01,               dst->nb[2]/sizeof(float),  // strideC
+                &beta,      (      char *)     dst_ddf, CUDA_R_32F, ne01,               dst->nb[2]/sizeof(float),  // strideC
                 ne12*ne13,
-                CUBLAS_COMPUTE_16F,
+                CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
@@ -7249,7 +7238,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
                 ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
-                ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
+                ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_ddf + i12* dst->nb[2]   + i13* dst->nb[3]  ;
             }
         }
@@ -7269,11 +7258,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                &alpha,     (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
                             (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+                &beta,      (      void **) (ptrs_as + 2*ne23), CUDA_R_32F, ne01,
                 ne23,
-                CUBLAS_COMPUTE_16F,
+                CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
         // free device memory for pointers
@@ -7282,11 +7271,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     }
 #endif
 
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
     ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {