better error checking

This commit is contained in:
slaren 2023-12-24 19:04:36 +01:00
parent a76cadad48
commit 6f35a4a6e9

View file

@@ -212,6 +212,28 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
[[noreturn]]
static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) {
    // Best-effort query of the active device; stays -1 when cudaGetDevice itself fails.
    int device = -1;
    cudaGetDevice(&device);
    fprintf(stderr, "CUDA error: %s\n", msg);
    fprintf(stderr, "  current device: %d, in function %s at %s:%d\n", device, func, file, line);
    fprintf(stderr, "  %s\n", stmt);
    // abort with GGML_ASSERT to get a stack trace
    GGML_ASSERT(!"CUDA error");
}
// Generic status check: evaluates `err` exactly once; on any value other than
// `success`, reports the failing statement, location and stringified code via
// ggml_cuda_error (which does not return). `error_fn` maps the status code to
// a human-readable message. Wrapped in do/while(0) so it is statement-safe.
#define CUDA_CHECK_GEN(err, success, error_fn) \
do { \
auto err_ = (err); \
if (err_ != (success)) { \
ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \
} \
} while (0)
// Check a CUDA runtime call: success value is cudaSuccess, messages come from cudaGetErrorString.
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
#if CUDART_VERSION >= 12000
static const char * cublas_get_error_str(const cublasStatus_t err) {
    return cublasGetStatusString(err);
@@ -233,15 +255,8 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
}
#endif // CUDART_VERSION >= 12000
-[[noreturn]]
-static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) {
-    fprintf(stderr, "CUDA error: %s: %s\n", stmt, msg);
-    fprintf(stderr, "  in function %s at %s:%d\n", func, file, line);
-    GGML_ASSERT(!"CUDA error");
-}
-#define CUDA_CHECK(err) do { auto err_ = (err); if (err_ != cudaSuccess) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cudaGetErrorString(err_)); } while (0)
-#define CUBLAS_CHECK(err) do { auto err_ = (err); if (err_ != CUBLAS_STATUS_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cublas_get_error_str(err_)); } while (0)
+#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
#if !defined(GGML_USE_HIPBLAS)
static const char * cu_get_error_str(CUresult err) {
@@ -249,7 +264,7 @@ static const char * cu_get_error_str(CUresult err) {
    cuGetErrorString(err, &err_str);
    return err_str;
}
-#define CU_CHECK(err) do { auto err_ = (err); if (err_ != CUDA_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cu_get_error_str(err_)); } while (0)
+#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif
#if CUDART_VERSION >= 11100
@@ -538,7 +553,6 @@ struct cuda_device_capabilities {
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
static void * g_scratch_buffer = nullptr;
static size_t g_scratch_size = 0; // disabled by default
static size_t g_scratch_offset = 0;
@@ -4727,7 +4741,6 @@ static __global__ void mul_mat_p021_f16_f32(
    const int row_y = col_x;
    // y is not transposed but permuted
    const int iy = channel*nrows_y + row_y;
@@ -7209,7 +7222,6 @@ inline void ggml_cuda_op_norm(
    (void) src1_dd;
}
inline void ggml_cuda_op_group_norm(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7784,7 +7796,6 @@ inline void ggml_cuda_op_im2col(
    (void) src0_dd;
}
inline void ggml_cuda_op_sum_rows(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {