From 04c0d480d780b7e43f9cd5726b1c1d66570b57d8 Mon Sep 17 00:00:00 2001 From: Henri Vasserman Date: Thu, 4 May 2023 12:31:16 +0300 Subject: [PATCH] Move all HIP stuff to ggml-cuda.cu --- CMakeLists.txt | 10 +++++----- ggml-cuda.cu | 44 +++++++++++++++++++++++++++++++++++++++++--- ggml-cuda.h | 46 ---------------------------------------------- 3 files changed, 46 insertions(+), 54 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e01bb2edd..79393a54e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -232,16 +232,16 @@ if (LLAMA_HIPBLAS) find_package(hipblas) if (${hipblas_FOUND} AND ${hip_FOUND}) - message(STATUS "hipBLAS found") - add_compile_definitions(GGML_USE_HIPBLAS) - add_library(ggml-hip OBJECT ggml-cuda.cu ggml-cuda.h) + message(STATUS "HIP and hipBLAS found") + add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS) + add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX) - target_link_libraries(ggml-hip PRIVATE hip::device) + target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::hipblas) if (LLAMA_STATIC) message(FATAL_ERROR "Static linking not supported for HIP/ROCm") endif() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} hip::host roc::hipblas ggml-hip) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm) else() message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm") endif() diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 033c7d5c8..9007f6dcb 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5,9 +5,47 @@ #include #if defined(GGML_USE_HIPBLAS) -#include "hip/hip_runtime.h" -#include "hipblas/hipblas.h" -#include "hip/hip_fp16.h" +#include +#include +#include +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F +#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N HIPBLAS_OP_N +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F HIPBLAS_R_16F +#define CUDA_R_32F HIPBLAS_R_32F +#define cublasCreate hipblasCreate +#define cublasGemmEx hipblasGemmEx +#define cublasHandle_t hipblasHandle_t +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetStream hipblasSetStream +#define cublasSgemm hipblasSgemm +#define cublasStatus_t hipblasStatus_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaError_t hipError_t +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDisableTiming hipEventDisableTiming +#define cudaEventRecord hipEventRecord +#define cudaEvent_t hipEvent_t +#define cudaFree hipFree +#define cudaFreeHost hipHostFree +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaMalloc hipMalloc +#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocPortable) +#define cudaMemcpy2DAsync hipMemcpy2DAsync +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamNonBlocking hipStreamNonBlocking +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamWaitEvent hipStreamWaitEvent +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess #else #include #include diff --git a/ggml-cuda.h b/ggml-cuda.h index 0e740e309..f7d6a8bc1 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -1,49 +1,3 @@ -#if defined(GGML_USE_HIPBLAS) -#include "hipblas/hipblas.h" -#include "hip/hip_runtime.h" -#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F -#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F -#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT -#define CUBLAS_OP_N HIPBLAS_OP_N -#define CUBLAS_OP_T HIPBLAS_OP_T -#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS -#define CUBLAS_TF32_TENSOR_OP_MATH 0 -#define CUDA_R_16F HIPBLAS_R_16F -#define CUDA_R_32F HIPBLAS_R_32F -#define cublasCreate hipblasCreate -#define cublasGemmEx hipblasGemmEx -#define cublasHandle_t hipblasHandle_t -#define cublasSetMathMode(h, m) HIPBLAS_STATUS_SUCCESS -#define cublasSetStream hipblasSetStream -#define cublasSgemm hipblasSgemm -#define cublasStatus_t hipblasStatus_t -#define cudaDeviceSynchronize hipDeviceSynchronize -#define cudaError_t hipError_t -#define cudaEventCreateWithFlags hipEventCreateWithFlags -#define cudaEventDisableTiming hipEventDisableTiming -#define cudaEventRecord hipEventRecord -#define cudaEvent_t hipEvent_t -#define cudaFree hipFree -#define cudaFreeHost hipFreeHost -#define cudaGetErrorString hipGetErrorString -#define cudaGetLastError hipGetLastError -#define cudaMalloc hipMalloc -#define cudaMallocHost hipMallocHost -#define cudaMemcpy2DAsync hipMemcpy2DAsync -#define cudaMemcpyAsync hipMemcpyAsync -#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost -#define cudaMemcpyHostToDevice hipMemcpyHostToDevice -#define cudaStreamCreateWithFlags hipStreamCreateWithFlags -#define cudaStreamNonBlocking hipStreamNonBlocking -#define cudaStreamSynchronize hipStreamSynchronize -#define cudaStreamWaitEvent hipStreamWaitEvent -#define cudaStream_t hipStream_t -#define cudaSuccess hipSuccess -#define GGML_USE_CUBLAS -#else -#include -#include -#endif #include "ggml.h" #ifdef __cplusplus