From 545f23d07b4f2bc83491f21a9aec801844eb9cff Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 23 Dec 2023 15:48:46 +0100 Subject: [PATCH] refactor error checking --- ggml-cuda.cu | 94 +++++++++++++++++++++------------------------------- 1 file changed, 37 insertions(+), 57 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 532292c78..3af228f58 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -201,63 +201,43 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); -#define CUDA_CHECK(err) \ - do { \ - cudaError_t err_ = (err); \ - if (err_ != cudaSuccess) { \ - int id; \ - cudaGetDevice(&id); \ - fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ - cudaGetErrorString(err_)); \ - fprintf(stderr, "current device: %d\n", id); \ - GGML_ASSERT(!"CUDA error"); \ - } \ - } while (0) - -// driver API -#define CU_CHECK(err) \ - do { \ - CUresult err_ = (err); \ - if (err_ != CUDA_SUCCESS) { \ - int id; \ - cuDeviceGet(&id, 0); \ - const char * err_str; \ - cuGetErrorString(err_, &err_str); \ - fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ - err_str); \ - fprintf(stderr, "%s\n", #err); \ - fprintf(stderr, "current device: %d\n", id); \ - GGML_ASSERT(!"CUDA error"); \ - } \ - } while (0) - - #if CUDART_VERSION >= 12000 -#define CUBLAS_CHECK(err) \ - do { \ - cublasStatus_t err_ = (err); \ - if (err_ != CUBLAS_STATUS_SUCCESS) { \ - int id; \ - cudaGetDevice(&id); \ - fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \ - err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \ - fprintf(stderr, "current device: %d\n", id); \ - GGML_ASSERT(!"cuBLAS error"); \ - } \ - } while (0) + static const char * cublas_get_error_str(const cublasStatus_t err) { + return cublasGetStatusString(err); + } #else -#define CUBLAS_CHECK(err) \ - do { \ - cublasStatus_t err_ = (err); \ - if (err_ != CUBLAS_STATUS_SUCCESS) { \ - int id; \ - cudaGetDevice(&id); \ - fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ - fprintf(stderr, "current device: %d\n", id); \ - GGML_ASSERT(!"cuBLAS error"); \ - } \ - } while (0) -#endif // CUDART_VERSION >= 11 + static const char * cublas_get_error_str(const cublasStatus_t err) { + switch (err) { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + default: return "unknown error"; + } +#endif // CUDART_VERSION >= 12000 + +static const char * cu_get_error_str(CUresult err) { + const char * err_str; + cuGetErrorString(err, &err_str); + return err_str; +} + +[[noreturn]] +static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) { + fprintf(stderr, "CUDA error: %s: %s\n", stmt, msg); + fprintf(stderr, " in function %s at %s:%d\n", func, file, line); + GGML_ASSERT(!"CUDA error"); +} + +#define CUDA_CHECK(err) do { auto err_ = (err); if (err_ != cudaSuccess) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cudaGetErrorString(err_)); } while (0) +#define CUBLAS_CHECK(err) do { auto err_ = (err); if (err_ != CUBLAS_STATUS_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cublas_get_error_str(err_)); } while (0) +#define CU_CHECK(err) do { auto err_ = (err); if (err_ != CUDA_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cu_get_error_str(err_)); } while (0) #if CUDART_VERSION >= 11100 #define GGML_CUDA_ASSUME(x) __builtin_assume(x) @@ -537,13 +517,13 @@ static int g_device_count = -1; static int g_main_device = 0; static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; -struct device_capabilities { +struct cuda_device_capabilities { int cc; // compute capability bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory }; -static device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} }; +static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} }; static void * g_scratch_buffer = nullptr;