diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 476e4fbf1..443378c6c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1,3 +1,4 @@ +#include #include #include #include @@ -253,6 +254,7 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre dequantize_block_q8_0<<>>(vx, y); } +// TODO: optimize static __global__ void convert_fp16_to_fp32(const void * vx, float * y) { const half * x = (const half *) vx; @@ -345,26 +347,31 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) { CUDA_CHECK(cudaFree(ptr)); } +#define GGML_CUDA_MAX_STREAMS 8 +#define GGML_CUDA_MAX_EVENTS 64 static cublasHandle_t g_cublasH = nullptr; -static cudaStream_t g_cudaStream = nullptr; -static cudaStream_t g_cudaStream2 = nullptr; -static cudaEvent_t g_cudaEvent = nullptr; +static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr }; +static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr }; +static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr }; void ggml_init_cublas() { if (g_cublasH == nullptr) { - // create cublas handle, bind a stream - CUBLAS_CHECK(cublasCreate(&g_cublasH)); - CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking)); - CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream)); - // enable tensor cores - CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TENSOR_OP_MATH)); + // create streams + for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) { + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking)); + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking)); + } + // create events + for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) { + CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming)); + } - // create additional stream and event for synchronization - CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream2, cudaStreamNonBlocking)); - CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvent, cudaEventDisableTiming)); + // create cublas handle + CUBLAS_CHECK(cublasCreate(&g_cublasH)); + CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH)); // configure logging to stdout - // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL)); + // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); } } @@ -433,39 +440,141 @@ static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * const int x_ne = ne01 * ne00; const int y_ne = ne11 * ne10; const int d_ne = ne11 * ne01; + const int n_mm = ne03 * ne02; size_t x_size, y_size, d_size; - float * d_X = (float *) ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); - float * d_Y = (float *) ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); - float * d_D = (float *) ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); + float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); + float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size); + float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size); for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { + int i = i03*ne02 + i02; + cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS]; + + float * c_X = d_X + i * x_ne; + float * c_Y = d_Y + i * y_ne; + float * c_D = d_D + i * d_ne; + // copy data to device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_X, src0, i03, i02, g_cudaStream)); - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream)); + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream)); + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); // compute + CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); CUBLAS_CHECK( cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - &alpha, d_X, ne00, - d_Y, ne10, - &beta, d_D, ne01)); + &alpha, c_X, ne00, + c_Y, ne10, + &beta, c_D, ne01)); - // copy data to host + // copy dst to host float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); - CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); + CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); } } - CUDA_CHECK(cudaStreamSynchronize(g_cudaStream)); + CUDA_CHECK(cudaDeviceSynchronize()); ggml_cuda_pool_free(d_X, x_size); ggml_cuda_pool_free(d_Y, y_size); ggml_cuda_pool_free(d_D, d_size); } -static void ggml_cuda_mul_mat_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) { + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const float alpha = 1.0f; + const float beta = 0.0f; + const int x_ne = ne01 * ne00; + const int y_ne = ne11 * ne10; + const int d_ne = ne11 * ne01; + const int n_mm = ne03 * ne02; + + size_t x_size, y_size, d_size; + half * d_X = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size); + half * d_Y = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size); + float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size); + + bool src1_cont_rows = nb10 == sizeof(float); + bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + int i = i03*ne02 + i02; + cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS]; + + half * c_X = d_X + i * x_ne; + half * c_Y = d_Y + i * y_ne; + float * c_D = d_D + i * d_ne; + + // copy src0 to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream)); + + // convert src1 to fp16 + // TODO: use multiple threads + ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02); + char * src1i = (char *) src1->data + i03*nb13 + i02*nb12; + if (src1_cont_rows) { + if (src1_cont_cols) { + ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11); + } + else { + for (int64_t i01 = 0; i01 < ne11; i01++) { + ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10); + } + } + } + else { + for (int64_t i01 = 0; i01 < ne11; i01++) { + for (int64_t i00 = 0; i00 < ne10; i00++) { + // very slow due to no inlining + tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10)); + } + } + } + + // copy src1 to device + CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream)); + + // compute + CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); + CUBLAS_CHECK( + cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha, c_X, CUDA_R_16F, ne00, + c_Y, CUDA_R_16F, ne10, + &beta, c_D, CUDA_R_32F, ne01, + CUBLAS_COMPUTE_32F_FAST_16F, + CUBLAS_GEMM_DEFAULT)); + + // copy dst to host + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); + } + } + + CUDA_CHECK(cudaDeviceSynchronize()); + ggml_cuda_pool_free(d_X, x_size); + ggml_cuda_pool_free(d_Y, y_size); + ggml_cuda_pool_free(d_D, d_size); +} + +static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; @@ -483,46 +592,58 @@ static void ggml_cuda_mul_mat_q(const ggml_tensor * src0, const ggml_tensor * sr const int x_ne = ne01 * ne00; const int y_ne = ne11 * ne10; const int d_ne = ne11 * ne01; + const int n_mm = ne03 * ne02; + const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type); size_t x_size, y_size, d_size, q_size; - float * d_X = (float *) ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size); - float * d_Y = (float *) ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size); - float * d_D = (float *) ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size); - void * d_Q = (void *) ggml_cuda_pool_malloc(ggml_type_size(type) * x_ne / ggml_blck_size(type), &q_size); + float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); + float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size); + float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size); + char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type); - GGML_ASSERT(to_fp32_cuda != NULL); + GGML_ASSERT(to_fp32_cuda != nullptr); for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - // copy and convert to fp32 on device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, src0, i03, i02, g_cudaStream2)); + int i = i03*ne02 + i02; + cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS]; + cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS]; + cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS]; - to_fp32_cuda(d_Q, d_X, x_ne, g_cudaStream2); + float * c_X = d_X + i * x_ne; + float * c_Y = d_Y + i * y_ne; + float * c_D = d_D + i * d_ne; + char * c_Q = d_Q + i * q_sz; + + // copy src0 and convert to fp32 on device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2)); + to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2); CUDA_CHECK(cudaGetLastError()); - CUDA_CHECK(cudaEventRecord(g_cudaEvent, g_cudaStream2)); + CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); - // copy data to device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Y, src1, i03, i02, g_cudaStream)); + // copy src1 to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); // wait for conversion - CUDA_CHECK(cudaStreamWaitEvent(g_cudaStream, g_cudaEvent, 0)); + CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); // compute + CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); CUBLAS_CHECK( cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - &alpha, d_X, ne00, - d_Y, ne10, - &beta, d_D, ne01)); + &alpha, c_X, ne00, + c_Y, ne10, + &beta, c_D, ne01)); - // copy data to host + // copy dst to host float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); - CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream)); + CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); } } - CUDA_CHECK(cudaStreamSynchronize(g_cudaStream)); + CUDA_CHECK(cudaDeviceSynchronize()); ggml_cuda_pool_free(d_X, x_size); ggml_cuda_pool_free(d_Y, y_size); ggml_cuda_pool_free(d_D, d_size); @@ -547,18 +668,48 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te return false; } -void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) { + size_t src0_sz = ggml_nbytes(src0); + size_t src1_sz = ggml_nbytes(src1); + + // mul_mat_q: src0 is converted to fp32 on device + size_t mul_mat_q_transfer = src0_sz + src1_sz; + + // mul_mat_f16: src1 is converted to fp16 on cpu + size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1); + + // choose the smaller one to transfer to the device + // TODO: this is not always the best choice due to the overhead of converting to fp16 + return mul_mat_f16_transfer < mul_mat_q_transfer; +} + +void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) { GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst)); - const ggml_type type = src0->type; - - if (type == GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32) { ggml_cuda_mul_mat_f32(src0, src1, dst); } - else if (type == GGML_TYPE_F16 || ggml_is_quantized(type)) { - ggml_cuda_mul_mat_q(src0, src1, dst); + else if (src0->type == GGML_TYPE_F16) { + if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) { + ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize); + } + else { + ggml_cuda_mul_mat_q_f32(src0, src1, dst); + } + } + else if (ggml_is_quantized(src0->type)) { + ggml_cuda_mul_mat_q_f32(src0, src1, dst); } else { GGML_ASSERT(false); } } + +size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) { + return ggml_nelements(src1) * sizeof(ggml_fp16_t); + } + else { + return 0; + } +} diff --git a/ggml-cuda.h b/ggml-cuda.h index 1a8efde80..f7d6a8bc1 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -7,7 +7,8 @@ extern "C" { void ggml_init_cublas(void); bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); -void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); +size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); +void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize); // TODO: export these with GGML_API void * ggml_cuda_host_malloc(size_t size); diff --git a/ggml.c b/ggml.c index 77ee16da3..f1b885d6b 100644 --- a/ggml.c +++ b/ggml.c @@ -362,6 +362,32 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) { return GGML_FP32_TO_FP16(x); } +void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n) { + for (size_t i = 0; i < n; i++) { + y[i] = GGML_FP16_TO_FP32(x[i]); + } +} + +void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n) { + size_t i = 0; +#if defined(__F16C__) + for (; i + 7 < n; i += 8) { + __m256 x_vec = _mm256_loadu_ps(x + i); + __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128((__m128i *)(y + i), y_vec); + } + for(; i + 3 < n; i += 4) { + __m128 x_vec = _mm_loadu_ps(x + i); + __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storel_epi64((__m128i *)(y + i), y_vec); + } +#endif + for (; i < n; i++) { + y[i] = GGML_FP32_TO_FP16(x[i]); + } +} + + // // timing // @@ -8193,7 +8219,7 @@ static void ggml_compute_forward_mul_mat_f32( #if defined(GGML_USE_CUBLAS) if (ggml_cuda_can_mul_mat(src0, src1, dst)) { if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { - ggml_cuda_mul_mat(src0, src1, dst); + ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); } return; } @@ -8368,7 +8394,7 @@ static void ggml_compute_forward_mul_mat_f16_f32( #if defined(GGML_USE_CUBLAS) if (ggml_cuda_can_mul_mat(src0, src1, dst)) { if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { - ggml_cuda_mul_mat(src0, src1, dst); + ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); } return; } @@ -8588,7 +8614,7 @@ static void ggml_compute_forward_mul_mat_q_f32( #if defined(GGML_USE_CUBLAS) if (ggml_cuda_can_mul_mat(src0, src1, dst)) { if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { - ggml_cuda_mul_mat(src0, src1, dst); + ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); } return; } @@ -11638,6 +11664,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) { node->n_tasks = 1; // TODO: this actually is doing nothing // the threads are still spinning + cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node); } else #endif diff --git a/ggml.h b/ggml.h index cd5765b6e..ef5a048c3 100644 --- a/ggml.h +++ b/ggml.h @@ -220,6 +220,9 @@ extern "C" { GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x); GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x); + GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n); + GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n); + struct ggml_object; struct ggml_context;