refactor error checking
This commit is contained in:
parent
4c0f300a2c
commit
545f23d07b
1 changed files with 37 additions and 57 deletions
94
ggml-cuda.cu
94
ggml-cuda.cu
|
@ -201,63 +201,43 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
|||
|
||||
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||
|
||||
#define CUDA_CHECK(err) \
|
||||
do { \
|
||||
cudaError_t err_ = (err); \
|
||||
if (err_ != cudaSuccess) { \
|
||||
int id; \
|
||||
cudaGetDevice(&id); \
|
||||
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
||||
cudaGetErrorString(err_)); \
|
||||
fprintf(stderr, "current device: %d\n", id); \
|
||||
GGML_ASSERT(!"CUDA error"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// driver API
|
||||
#define CU_CHECK(err) \
|
||||
do { \
|
||||
CUresult err_ = (err); \
|
||||
if (err_ != CUDA_SUCCESS) { \
|
||||
int id; \
|
||||
cuDeviceGet(&id, 0); \
|
||||
const char * err_str; \
|
||||
cuGetErrorString(err_, &err_str); \
|
||||
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
||||
err_str); \
|
||||
fprintf(stderr, "%s\n", #err); \
|
||||
fprintf(stderr, "current device: %d\n", id); \
|
||||
GGML_ASSERT(!"CUDA error"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
#if CUDART_VERSION >= 12000
|
||||
#define CUBLAS_CHECK(err) \
|
||||
do { \
|
||||
cublasStatus_t err_ = (err); \
|
||||
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
||||
int id; \
|
||||
cudaGetDevice(&id); \
|
||||
fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
|
||||
err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
|
||||
fprintf(stderr, "current device: %d\n", id); \
|
||||
GGML_ASSERT(!"cuBLAS error"); \
|
||||
} \
|
||||
} while (0)
|
||||
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
||||
return cublasGetStatusString(err);
|
||||
}
|
||||
#else
|
||||
#define CUBLAS_CHECK(err) \
|
||||
do { \
|
||||
cublasStatus_t err_ = (err); \
|
||||
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
||||
int id; \
|
||||
cudaGetDevice(&id); \
|
||||
fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
|
||||
fprintf(stderr, "current device: %d\n", id); \
|
||||
GGML_ASSERT(!"cuBLAS error"); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif // CUDART_VERSION >= 11
|
||||
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
||||
switch (err) {
|
||||
case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
|
||||
case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
|
||||
case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
|
||||
case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
|
||||
case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
|
||||
case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
|
||||
case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
|
||||
case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
|
||||
case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
|
||||
case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
|
||||
default: return "unknown error";
|
||||
}
|
||||
#endif // CUDART_VERSION >= 12000
|
||||
|
||||
static const char * cu_get_error_str(CUresult err) {
|
||||
const char * err_str;
|
||||
cuGetErrorString(err, &err_str);
|
||||
return err_str;
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) {
|
||||
fprintf(stderr, "CUDA error: %s: %s\n", stmt, msg);
|
||||
fprintf(stderr, " in function %s at %s:%d\n", func, file, line);
|
||||
GGML_ASSERT(!"CUDA error");
|
||||
}
|
||||
|
||||
#define CUDA_CHECK(err) do { auto err_ = (err); if (err_ != cudaSuccess) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cudaGetErrorString(err_)); } while (0)
|
||||
#define CUBLAS_CHECK(err) do { auto err_ = (err); if (err_ != CUBLAS_STATUS_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cublas_get_error_str(err_)); } while (0)
|
||||
#define CU_CHECK(err) do { auto err_ = (err); if (err_ != CUDA_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cu_get_error_str(err_)); } while (0)
|
||||
|
||||
#if CUDART_VERSION >= 11100
|
||||
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
||||
|
@ -537,13 +517,13 @@ static int g_device_count = -1;
|
|||
static int g_main_device = 0;
|
||||
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
|
||||
|
||||
struct device_capabilities {
|
||||
struct cuda_device_capabilities {
|
||||
int cc; // compute capability
|
||||
bool vmm; // virtual memory support
|
||||
size_t vmm_granularity; // granularity of virtual memory
|
||||
};
|
||||
|
||||
static device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
|
||||
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
|
||||
|
||||
|
||||
static void * g_scratch_buffer = nullptr;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue