diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e9677943a..007b89359 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -86,6 +86,15 @@ #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess #define __trap abort +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED +#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED +#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE +#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH +#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR +#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED +#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR +#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED #else #include #include @@ -218,18 +227,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; - case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; default: return "unknown error"; } } #endif // CUDART_VERSION >= 12000 -static const char * cu_get_error_str(CUresult err) { - const char * err_str; - cuGetErrorString(err, &err_str); - return err_str; -} - [[noreturn]] static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) { fprintf(stderr, "CUDA error: %s: %s\n", stmt, msg); @@ -239,7 +241,15 @@ static void ggml_cuda_error(const char * stmt, const char * func, const char * f #define CUDA_CHECK(err) do { auto err_ = (err); if (err_ != cudaSuccess) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cudaGetErrorString(err_)); } while (0) #define CUBLAS_CHECK(err) do { auto err_ = (err); if (err_ != CUBLAS_STATUS_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cublas_get_error_str(err_)); } while (0) + +#if !defined(GGML_USE_HIPBLAS) +static const char * cu_get_error_str(CUresult err) { + const char * err_str; + cuGetErrorString(err, &err_str); + return err_str; +} #define CU_CHECK(err) do { auto err_ = (err); if (err_ != CUDA_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cu_get_error_str(err_)); } while (0) +#endif #if CUDART_VERSION >= 11100 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)