From 2a86c00ffa7282a0f1f968aa0f170a2e5f4a9044 Mon Sep 17 00:00:00 2001
From: Quinten Kock
Date: Sat, 16 Dec 2023 17:40:56 +0100
Subject: [PATCH] Add basic chipStar support

---
 Makefile     |  21 ++++
 ggml-cuda.cu | 323 ++++++++++++++++++++++++++-------------------
 2 files changed, 184 insertions(+), 160 deletions(-)

diff --git a/Makefile b/Makefile
index fb775ae5b..87355f293 100644
--- a/Makefile
+++ b/Makefile
@@ -460,6 +460,27 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
 	$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
 endif # LLAMA_HIPBLAS
 
+ifdef LLAMA_CHIPSTAR
+	CUSPVC                  ?= cuspvc
+	LLAMA_CUDA_DMMV_X       ?= 32
+	LLAMA_CUDA_MMV_Y        ?= 1
+	LLAMA_CUDA_KQUANTS_ITER ?= 2
+	MK_CPPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS -DGGML_USE_CHIPSTAR
+	MK_CPPFLAGS += -I/opt/H4IBLAS/include
+	MK_LDFLAGS  += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
+#	MK_LDFLAGS  += -lhipblas -lamdhip64 -lrocblas
+	MK_LDFLAGS  += -L/opt/chipstar/lib -lCHIP -L/opt/H4IBLAS/lib -lhipblas -L/opt/H4IMKL/lib -lMKLShim
+#	HIPFLAGS    += $(addprefix --offload-arch=,$(GPU_TARGETS))
+	HIPFLAGS    += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
+	HIPFLAGS    += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
+	HIPFLAGS    += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
+	HIPFLAGS    += -DGGML_CUDA_FORCE_DMMV
+	OBJS        += ggml-cuda.o
+ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
+	$(CUSPVC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
+endif # LLAMA_CHIPSTAR
+
+
 ifdef LLAMA_METAL
 MK_CPPFLAGS += -DGGML_USE_METAL
 MK_LDFLAGS  += -framework Foundation -framework Metal -framework MetalKit
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 0a63c1ecf..86a39d96d 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -13,7 +13,7 @@
 
 #if defined(GGML_USE_HIPBLAS)
 #include <hip/hip_runtime.h>
-#include <hipblas/hipblas.h>
+#include <hipblas.h>
 #include <hip/hip_fp16.h>
 #ifdef __HIP_PLATFORM_AMD__
 // for rocblas_initialize()
@@ -106,7 +106,7 @@
 // TODO: improve this to be correct for more hardware
 // for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
 // probably other such cases, and not sure what happens on AMD hardware
-#if !defined(GGML_CUDA_FORCE_MMQ)
+#if !defined(GGML_CUDA_FORCE_MMQ) && !defined(GGML_USE_CHIPSTAR)
 #define CUDA_USE_TENSOR_CORES
 #endif
 
@@ -5499,23 +5499,23 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
 
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    if (ncols < 1024) {
+    if (ncols < 256) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     } else {
-        const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        const dim3 block_dims(256, 1, 1);
+        norm_f32<256><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     }
 }
 
 static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
     static const float eps = 1e-6f;
-    if (group_size < 1024) {
+    if (group_size < 256) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
     } else {
-        const dim3 block_dims(1024, 1, 1);
-        group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+        const dim3 block_dims(256, 1, 1);
+        group_norm_f32<256><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
     }
 }
 
@@ -5542,12 +5542,12 @@ static void pad_f32_cuda(const float * x, float * dst,
 
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    if (ncols < 1024) {
+    if (ncols < 256) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     } else {
-        const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        const dim3 block_dims(256, 1, 1);
+        rms_norm_f32<256><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     }
 }
 
@@ -7376,6 +7376,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
 
     const int compute_capability = g_compute_capabilities[id];
 
+#ifndef GGML_USE_CHIPSTAR
     if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         half * src0_as_f16 = nullptr;
@@ -7428,7 +7429,9 @@ inline void ggml_cuda_op_mul_mat_cublas(
             ggml_cuda_pool_free(src1_as_f16, src1_as);
         }
     }
-    else {
+    else
+#endif
+    {
         float * src0_ddq_as_f32 = nullptr;
         size_t src0_as = 0;
 
@@ -8323,153 +8326,153 @@ static __global__ void k_compute_batched_ptrs(
         ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst_f16 + i12* nb2/2  + i13* nb3/2;
 }
 
-static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(!ggml_is_transposed(src0));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-
-    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const int64_t nb01 = src0->nb[1];
-    const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
-    const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    const int64_t nb11 = src1->nb[1];
-    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
-    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
-
-    const int64_t ne1 = ggml_nelements(src1);
-    const int64_t ne  = ggml_nelements(dst);
-
-    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
-
-    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
-
-    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
-    void * src0_ddq = src0_extra->data_device[g_main_device];
-    half * src0_as_f16 = (half *) src0_ddq;
-
-    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
-
-    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
-    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
-
-    // convert src1 to fp16
-    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
-    GGML_ASSERT(to_fp16_cuda != nullptr);
-
-    size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
-    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
-
-    size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-
-    GGML_ASSERT(ne12 % ne02 == 0);
-    GGML_ASSERT(ne13 % ne03 == 0);
-
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-    const half alpha_f16 = 1.0f;
-    const half beta_f16  = 0.0f;
-
-#if 0
-    // use cublasGemmEx
-    {
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                int i03 = i13 / r3;
-                int i02 = i12 / r2;
-
-                CUBLAS_CHECK(
-                cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
-                    ne01, ne11, ne10,
-                    &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F,   nb01/sizeof(half),
-                                (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F,   nb11/sizeof(float),
-                    &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F,   ne01,
-                    CUBLAS_COMPUTE_16F,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-            }
-        }
-    }
-#else
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
-        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        // use cublasGemmStridedBatchedEx
-        CUBLAS_CHECK(
-        cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
-                ne01, ne11, ne10,
-                &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F,   nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
-                            (const char *) src1_as_f16, CUDA_R_16F,   nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
-                &beta_f16,  (      char *)     dst_f16, CUDA_R_16F,   ne01,                dst->nb[2]/sizeof(float), // strideC
-                ne12*ne13,
-                CUBLAS_COMPUTE_16F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-    } else {
-        // use cublasGemmBatchedEx
-        const int ne23 = ne12*ne13;
-
-        const void ** ptrs_src = nullptr;
-              void ** ptrs_dst = nullptr;
-
-        size_t ptrs_src_s = 0;
-        size_t ptrs_dst_s = 0;
-
-        ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
-        ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
-
-        dim3 block_dims(ne13, ne12);
-        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
-                src0_as_f16, src1_as_f16, dst_f16,
-                ptrs_src, ptrs_dst,
-                ne12, ne13,
-                ne23,
-                nb02, nb03,
-                nb12, nb13,
-                dst->nb[2], dst->nb[3],
-                r2, r3);
-        CUDA_CHECK(cudaGetLastError());
-
-        CUBLAS_CHECK(
-        cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
-                ne01, ne11, ne10,
-                &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F,   nb01/sizeof(half),
-                            (const void **) (ptrs_src + 1*ne23), CUDA_R_16F,   nb11/sizeof(float),
-                &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F,   ne01,
-                ne23,
-                CUBLAS_COMPUTE_16F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-        if (ptrs_src_s != 0) {
-            ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
-        }
-        if (ptrs_dst_s != 0) {
-            ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
-        }
-    }
-#endif
-
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-    ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
-}
+// static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+//     GGML_ASSERT(!ggml_is_transposed(src0));
+//     GGML_ASSERT(!ggml_is_transposed(src1));
+//
+//     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+//     GGML_ASSERT(src0->type == GGML_TYPE_F16);
+//     GGML_ASSERT(src1->type == GGML_TYPE_F32);
+//
+//     const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
+//     const int64_t ne01 = src0->ne[1];
+//     const int64_t ne02 = src0->ne[2];
+//     const int64_t ne03 = src0->ne[3];
+//
+//     const int64_t nb01 = src0->nb[1];
+//     const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
+//     const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
+//
+//     const int64_t ne10 = src1->ne[0];
+//     const int64_t ne11 = src1->ne[1];
+//     const int64_t ne12 = src1->ne[2];
+//     const int64_t ne13 = src1->ne[3];
+//
+//     const int64_t nb11 = src1->nb[1];
+//     const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
+//     const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
+//
+//     const int64_t ne1 = ggml_nelements(src1);
+//     const int64_t ne  = ggml_nelements(dst);
+//
+//     CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+//     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+//
+//     CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
+//
+//     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+//     void * src0_ddq = src0_extra->data_device[g_main_device];
+//     half * src0_as_f16 = (half *) src0_ddq;
+//
+//     ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+//     float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+//
+//     ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+//     float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+//
+//     // convert src1 to fp16
+//     const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+//     GGML_ASSERT(to_fp16_cuda != nullptr);
+//
+//     size_t src1_as = 0;
+//     half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+//     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+//
+//     size_t dst_as = 0;
+//     half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+//
+//     GGML_ASSERT(ne12 % ne02 == 0);
+//     GGML_ASSERT(ne13 % ne03 == 0);
+//
+//     // broadcast factors
+//     const int64_t r2 = ne12/ne02;
+//     const int64_t r3 = ne13/ne03;
+//
+//     const half alpha_f16 = 1.0f;
+//     const half beta_f16  = 0.0f;
+//
+// #if 0
+//     // use cublasGemmEx
+//     {
+//         for (int i13 = 0; i13 < ne13; ++i13) {
+//             for (int i12 = 0; i12 < ne12; ++i12) {
+//                 int i03 = i13 / r3;
+//                 int i02 = i12 / r2;
+//
+//                 CUBLAS_CHECK(
+//                 cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+//                     ne01, ne11, ne10,
+//                     &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F,   nb01/sizeof(half),
+//                                 (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F,   nb11/sizeof(float),
+//                     &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F,   ne01,
+//                     CUBLAS_COMPUTE_16F,
+//                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+//             }
+//         }
+//     }
+// #else
+//     if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+//         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+//         // use cublasGemmStridedBatchedEx
+//         CUBLAS_CHECK(
+//         cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
+//                 ne01, ne11, ne10,
+//                 &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F,   nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+//                             (const char *) src1_as_f16, CUDA_R_16F,   nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+//                 &beta_f16,  (      char *)     dst_f16, CUDA_R_16F,   ne01,                dst->nb[2]/sizeof(float), // strideC
+//                 ne12*ne13,
+//                 CUBLAS_COMPUTE_16F,
+//                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+//     } else {
+//         // use cublasGemmBatchedEx
+//         const int ne23 = ne12*ne13;
+//
+//         const void ** ptrs_src = nullptr;
+//               void ** ptrs_dst = nullptr;
+//
+//         size_t ptrs_src_s = 0;
+//         size_t ptrs_dst_s = 0;
+//
+//         ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
+//         ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
+//
+//         dim3 block_dims(ne13, ne12);
+//         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+//                 src0_as_f16, src1_as_f16, dst_f16,
+//                 ptrs_src, ptrs_dst,
+//                 ne12, ne13,
+//                 ne23,
+//                 nb02, nb03,
+//                 nb12, nb13,
+//                 dst->nb[2], dst->nb[3],
+//                 r2, r3);
+//         CUDA_CHECK(cudaGetLastError());
+//
+//         CUBLAS_CHECK(
+//         cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
+//                 ne01, ne11, ne10,
+//                 &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F,   nb01/sizeof(half),
+//                             (const void **) (ptrs_src + 1*ne23), CUDA_R_16F,   nb11/sizeof(float),
+//                 &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F,   ne01,
+//                 ne23,
+//                 CUBLAS_COMPUTE_16F,
+//                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+//
+//         if (ptrs_src_s != 0) {
+//             ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
+//         }
+//         if (ptrs_dst_s != 0) {
+//             ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
+//         }
+//     }
+// #endif
+//
+//     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+//     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+//
+//     ggml_cuda_pool_free(src1_as_f16, src1_as);
+//     ggml_cuda_pool_free(dst_f16, dst_as);
+// }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
@@ -8508,7 +8511,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
     } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
         // KQ + KQV multi-batch
-        ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
+        // ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
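
Build note (a minimal usage sketch, assuming chipStar, H4I-HipBLAS and H4I-MKLShim are installed under the /opt/chipstar, /opt/H4IBLAS and /opt/H4IMKL prefixes hard-coded in the Makefile hunk above): the new path is selected the same way as the other backend switches in the Makefile, e.g.

    # build with the chipStar backend enabled; cuspvc must be on PATH
    make LLAMA_CHIPSTAR=1

cuspvc then compiles ggml-cuda.cu as HIP source via the `-x hip` rule added above. Other install prefixes require editing the Makefile directly, since the patch does not expose them as variables.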