From f06ddb2e7963a80abc524178b874bbdedd9dd821 Mon Sep 17 00:00:00 2001
From: Vlad
Date: Tue, 12 Sep 2023 01:42:42 +0300
Subject: [PATCH] CUDA 11.0 fixes

---
 Makefile     |  7 +++-
 ggml-cuda.cu | 92 ++++++++++++++++++++++++++++------------------------
 2 files changed, 55 insertions(+), 44 deletions(-)

diff --git a/Makefile b/Makefile
index a774dc50f..593af77df 100644
--- a/Makefile
+++ b/Makefile
@@ -373,7 +373,7 @@ ifdef LLAMA_CUDA_CCBIN
 	NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
 endif
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
-	$(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) -Wno-pedantic -c $< -o $@
+	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS_CUDA) -Wno-pedantic -c $< -o $@
 endif # LLAMA_CUBLAS
 
 ifdef LLAMA_CLBLAST
@@ -446,6 +446,11 @@ override CFLAGS   := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CFLAGS) $(CFLAGS)
 override CXXFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CXXFLAGS) $(CXXFLAGS)
 override LDFLAGS  := $(MK_LDFLAGS) $(LDFLAGS)
 
+COMMA := ,
+CXXFLAGS_CUDA := $(CXXFLAGS)
+CXXFLAGS_CUDA := $(subst -march=native -mtune=native,--compiler-options=-march=native$(COMMA)-mtune=native,$(CXXFLAGS_CUDA))
+CXXFLAGS_CUDA := $(subst -Ofast,-O3,$(CXXFLAGS_CUDA))
+
 #
 # Print build information
 #
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index a14e2362a..d021ac879 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -61,7 +61,7 @@
 #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
 #define cudaStreamNonBlocking hipStreamNonBlocking
 #define cudaStreamSynchronize hipStreamSynchronize
-#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
+#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
 #else
@@ -180,6 +180,12 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 } while (0)
 #endif // CUDART_VERSION >= 11
 
+#if CUDART_VERSION >= 11100
+#define GGML_ASSUME(x) __builtin_assume(x)
+#else
+#define GGML_ASSUME(x)
+#endif // CUDART_VERSION >= 11100
+
 #ifdef GGML_CUDA_F16
 typedef half dfloat; // dequantize float
 typedef half2 dfloat2;
@@ -2135,10 +2141,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_ASSUME(i_offset >= 0);
+    GGML_ASSUME(i_offset <  nwarps);
+    GGML_ASSUME(k >= 0);
+    GGML_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI4_0;
     const int kqsx = k % QI4_0;
@@ -2229,10 +2235,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_ASSUME(i_offset >= 0);
+    GGML_ASSUME(i_offset <  nwarps);
+    GGML_ASSUME(k >= 0);
+    GGML_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI4_1;
     const int kqsx = k % QI4_1;
@@ -2321,10 +2327,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_ASSUME(i_offset >= 0);
+    GGML_ASSUME(i_offset <  nwarps);
+    GGML_ASSUME(k >= 0);
+    GGML_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI5_0;
     const int kqsx = k % QI5_0;
@@ -2435,10 +2441,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_ASSUME(i_offset >= 0);
+    GGML_ASSUME(i_offset <  nwarps);
+    GGML_ASSUME(k >= 0);
+    GGML_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI5_1;
     const int kqsx = k % QI5_1;
@@ -2541,10 +2547,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_ASSUME(i_offset >= 0);
+    GGML_ASSUME(i_offset <  nwarps);
+    GGML_ASSUME(k >= 0);
+    GGML_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI8_0;
     const int kqsx = k % QI8_0;
@@ -2632,10 +2638,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_ASSUME(i_offset >= 0);
+    GGML_ASSUME(i_offset <  nwarps);
+    GGML_ASSUME(k >= 0);
+    GGML_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI2_K;
     const int kqsx = k % QI2_K;
@@ -2753,10 +2759,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_ASSUME(i_offset >= 0);
+    GGML_ASSUME(i_offset <  nwarps);
+    GGML_ASSUME(k >= 0);
+    GGML_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI3_K;
     const int kqsx = k % QI3_K;
@@ -2971,10 +2977,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_ASSUME(i_offset >= 0);
+    GGML_ASSUME(i_offset <  nwarps);
+    GGML_ASSUME(k >= 0);
+    GGML_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI4_K; // == 0 if QK_K == 256
     const int kqsx = k % QI4_K; // == k if QK_K == 256
@@ -3152,10 +3158,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_ASSUME(i_offset >= 0);
+    GGML_ASSUME(i_offset <  nwarps);
+    GGML_ASSUME(k >= 0);
+    GGML_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI5_K; // == 0 if QK_K == 256
     const int kqsx = k % QI5_K; // == k if QK_K == 256
@@ -3281,10 +3287,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
 
-    __builtin_assume(i_offset >= 0);
-    __builtin_assume(i_offset <  nwarps);
-    __builtin_assume(k >= 0);
-    __builtin_assume(k <  WARP_SIZE);
+    GGML_ASSUME(i_offset >= 0);
+    GGML_ASSUME(i_offset <  nwarps);
+    GGML_ASSUME(k >= 0);
+    GGML_ASSUME(k <  WARP_SIZE);
 
     const int kbx  = k / QI6_K; // == 0 if QK_K == 256
     const int kqsx = k % QI6_K; // == k if QK_K == 256
@@ -6021,7 +6027,7 @@ static void ggml_cuda_op_mul_mat(
 
         // wait for main GPU data if necessary
         if (split && (id != g_main_device || is != 0)) {
-            CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0]));
+            CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[g_main_device][0], 0));
         }
 
         for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
@@ -6143,7 +6149,7 @@ static void ggml_cuda_op_mul_mat(
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         for (int64_t id = 0; id < g_device_count; ++id) {
             for (int64_t is = 0; is < is_max; ++is) {
-                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is]));
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams[g_main_device][0], src0_extra->events[id][is], 0));
             }
         }
     }
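
Notes (not part of the patch, added for review context):

- The Makefile now hands nvcc a dedicated CXXFLAGS_CUDA: the host-compiler
  flags -march=native -mtune=native are routed through nvcc's
  --compiler-options=... (which wants them comma-separated, hence the COMMA
  helper), and -Ofast, which nvcc does not accept, is still downgraded to -O3.
- GGML_ASSUME(x) expands to __builtin_assume(x) only when CUDART_VERSION >=
  11100, so the optimizer hint is kept on CUDA 11.1+ and compiled out on
  CUDA 11.0, where the patch indicates nvcc rejects it.
- Every cudaStreamWaitEvent call now passes the flags argument explicitly
  (and the HIP wrapper macro forwards it), since the two-argument form
  relied on a default argument that the patch indicates CUDA 11.0 headers
  do not provide.

Below is a minimal, self-contained CUDA sketch of the same two source-level
patterns. It is hypothetical: scale_kernel and its launch configuration are
invented for illustration and are not taken from ggml-cuda.cu.

// Hypothetical standalone repro for the CUDA 11.0 portability patterns.
#include <cuda_runtime.h>
#include <cstdio>

// Same guard as the patch: only CUDA 11.1+ (CUDART_VERSION 11100) gets the
// real __builtin_assume; on CUDA 11.0 the hint compiles away to nothing.
#if CUDART_VERSION >= 11100
#define GGML_ASSUME(x) __builtin_assume(x)
#else
#define GGML_ASSUME(x)
#endif // CUDART_VERSION >= 11100

static __global__ void scale_kernel(float * x, int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    GGML_ASSUME(threadIdx.x < 128); // valid for the 128-thread blocks below
    if (i < n) {
        x[i] *= 2.0f;
    }
}

int main() {
    const int n = 256;
    float * x;
    cudaMalloc(&x, n*sizeof(float));

    cudaStream_t stream;
    cudaEvent_t  event;
    cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);
    cudaEventCreateWithFlags(&event, cudaEventDisableTiming);

    scale_kernel<<<(n + 127)/128, 128>>>(x, n); // runs on the default stream
    cudaEventRecord(event, 0);                  // record on the default stream

    // Pass flags explicitly: the two-argument form is what failed to build
    // on CUDA 11.0, so the 0 is spelled out here just as in the patch.
    cudaStreamWaitEvent(stream, event, 0);
    cudaStreamSynchronize(stream);

    cudaEventDestroy(event);
    cudaStreamDestroy(stream);
    cudaFree(x);
    printf("ok\n");
    return 0;
}

Compiled with a CUDA 11.0 toolkit this builds with the hint dropped and the
explicit third argument keeping the cudaStreamWaitEvent call valid; with
11.1+ the assume hint is passed through to the optimizer unchanged.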