cuBLAS: refactor and optimize f16 mat mul performance (#1259)

* cuBLAS: refactor, convert fp16 to fp32 on device

* cuBLAS: use multiple streams, choose smartly between mul_mat_q and mul_mat_f16

* fix build

* cuBLAS: update block_q5_1
This commit is contained in:
slaren 2023-05-01 18:11:07 +02:00 committed by GitHub
parent ea3a0ad6b6
commit 58b367c2d7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 480 additions and 259 deletions

252
ggml.c
View file

@ -135,14 +135,6 @@ inline static void* ggml_aligned_malloc(size_t size) {
#define UNUSED(x) (void)(x)
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
#define GGML_ASSERT(x) \
do { \
if (!(x)) { \
fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
abort(); \
} \
} while (0)
#if defined(GGML_USE_ACCELERATE)
#include <Accelerate/Accelerate.h>
#elif defined(GGML_USE_OPENBLAS)
@ -370,6 +362,32 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
return GGML_FP32_TO_FP16(x);
}
void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) {
for (size_t i = 0; i < n; i++) {
y[i] = GGML_FP16_TO_FP32(x[i]);
}
}
void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) {
size_t i = 0;
#if defined(__F16C__)
for (; i + 7 < n; i += 8) {
__m256 x_vec = _mm256_loadu_ps(x + i);
__m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
_mm_storeu_si128((__m128i *)(y + i), y_vec);
}
for(; i + 3 < n; i += 4) {
__m128 x_vec = _mm_loadu_ps(x + i);
__m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
_mm_storel_epi64((__m128i *)(y + i), y_vec);
}
#endif
for (; i < n; i++) {
y[i] = GGML_FP32_TO_FP16(x[i]);
}
}
//
// timing
//
@ -4325,12 +4343,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
}
// initialize cuBLAS
#if defined(GGML_USE_CUBLAS)
#if defined(GGML_USE_CUBLAS)
ggml_init_cublas();
#elif defined(GGML_USE_CLBLAST)
#elif defined(GGML_USE_CLBLAST)
ggml_cl_init();
#endif
#endif
is_first_call = false;
}
@ -8101,7 +8118,7 @@ static void ggml_compute_forward_rms_norm(
// ggml_compute_forward_mul_mat
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
// helper function to determine if it is better to use BLAS or not
// for large matrices, BLAS is faster
static bool ggml_compute_forward_mul_mat_use_blas(
@ -8117,12 +8134,9 @@ static bool ggml_compute_forward_mul_mat_use_blas(
const int64_t ne1 = dst->ne[1];
// TODO: find the optimal values for these
if (
#if !defined(GGML_USE_CUBLAS)
ggml_is_contiguous(src0) &&
if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
#endif
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
/*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
return true;
@ -8130,7 +8144,6 @@ static bool ggml_compute_forward_mul_mat_use_blas(
return false;
}
#endif
static void ggml_compute_forward_mul_mat_f32(
@ -8146,7 +8159,7 @@ static void ggml_compute_forward_mul_mat_f32(
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
const int64_t ne10 = src1->ne[0];
#endif
const int64_t ne11 = src1->ne[1];
@ -8203,7 +8216,16 @@ static void ggml_compute_forward_mul_mat_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
}
return;
}
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
if (params->ith != 0) {
return;
@ -8217,43 +8239,13 @@ static void ggml_compute_forward_mul_mat_f32(
return;
}
#if defined(GGML_USE_CUBLAS)
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
size_t x_size, y_size, d_size;
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
#if !defined(GGML_USE_CUBLAS)
const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
#endif
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
#if defined(GGML_USE_CUBLAS)
// copy data to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
// compute
CUBLAS_CHECK(
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, d_X, ne00,
d_Y, ne10,
&beta, d_D, ne01));
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
#elif defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CLBLAST)
// zT = y * xT
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
ne11, ne01, ne10,
@ -8270,12 +8262,6 @@ static void ggml_compute_forward_mul_mat_f32(
#endif
}
}
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
#endif
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
return;
@ -8405,7 +8391,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
}
return;
}
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
GGML_ASSERT(nb10 == sizeof(float));
@ -8421,37 +8416,8 @@ static void ggml_compute_forward_mul_mat_f16_f32(
return;
}
#if defined(GGML_USE_CUBLAS)
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
size_t x_size, y_size, d_size;
ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
#if defined(GGML_USE_CUBLAS)
// copy src0 while converting src1
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream));
// with cuBlAS, instead of converting src0 to fp32, we convert src1 to fp16
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + (ne11 * ne10) * (i03 * ne02 + i02);
{
size_t id = 0;
for (int64_t i01 = 0; i01 < ne11; ++i01) {
for (int64_t i00 = 0; i00 < ne10; ++i00) {
wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
}
}
assert(id*sizeof(ggml_fp16_t) <= params->wsize);
}
#else
float * const wdata = params->wdata;
{
size_t id = 0;
@ -8463,28 +8429,8 @@ static void ggml_compute_forward_mul_mat_f16_f32(
assert(id*sizeof(float) <= params->wsize);
}
#endif
#if defined(GGML_USE_CUBLAS)
const ggml_fp16_t * y = (ggml_fp16_t *) wdata;
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
// copy data to device
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
// compute
CUBLAS_CHECK(
cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, d_X, CUDA_R_16F, ne00,
d_Y, CUDA_R_16F, ne10,
&beta, d_D, CUDA_R_32F, ne01,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT));
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
#elif defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CLBLAST)
const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@ -8513,12 +8459,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
}
}
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
#endif
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
return;
@ -8671,7 +8611,16 @@ static void ggml_compute_forward_mul_mat_q_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
}
return;
}
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
if (params->ith != 0) {
return;
@ -8685,25 +8634,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
return;
}
#if defined(GGML_USE_CUBLAS)
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
size_t x_size, y_size, d_size, q_size;
float * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
float * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
void * d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
const dequantize_row_q_cuda_t dequantize_row_q_cuda = ggml_get_dequantize_row_q_cuda(type);
GGML_ASSERT(dequantize_row_q_cuda != NULL);
#else
float * const wdata = params->wdata;
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
@ -8711,14 +8643,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
#if defined(GGML_USE_CUBLAS)
// copy and dequantize on device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2));
dequantize_row_q_cuda(d_Q, d_X, x_ne, g_cudaStream2);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2));
#elif defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CLBLAST)
const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
#else
{
@ -8734,24 +8659,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
const float * x = wdata;
#endif
#if defined(GGML_USE_CUBLAS)
// copy data to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream));
// wait for dequantization
CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0));
// compute
CUBLAS_CHECK(
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, d_X, ne00,
d_Y, ne10,
&beta, d_D, ne01));
// copy data to host
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
#elif defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_CLBLAST)
// zT = y * xT
ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
ne11, ne01, ne10,
@ -8769,13 +8677,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
}
}
#if defined(GGML_USE_CUBLAS)
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
ggml_cuda_pool_free(d_X, x_size);
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
ggml_cuda_pool_free(d_Q, q_size);
#endif
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
return;
@ -11759,18 +11660,21 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
size_t cur = 0;
#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
}
else
#endif
if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
#if defined(GGML_USE_CUBLAS)
// with cuBLAS, we need memory for the full 3D / 4D data of src1
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
#else
// here we need memory just for single 2D matrix from src0
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
#endif
} else {
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
}
@ -11779,13 +11683,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
#endif
} else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
cur = 0;
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1;
}
#endif
} else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CLBLAST)
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
node->n_tasks = 1;
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);