diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 9da4e8008..bc122c9b8 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -14,9 +14,11 @@ // for rocblas_initialize() #include "rocblas/rocblas.h" #endif // __HIP_PLATFORM_AMD__ +#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N #define CUBLAS_OP_T HIPBLAS_OP_T #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS