use hipblas based on cublas
This commit is contained in:
parent 2005469ea1
commit 0fd8363adc

4 changed files with 69 additions and 2 deletions
CMakeLists.txt

@@ -67,6 +67,7 @@ endif()
 option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
 option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
 option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
+option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)

 option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@@ -168,6 +169,31 @@ if (LLAMA_CUBLAS)
     endif()
 endif()

+if (LLAMA_HIPBLAS)
+    cmake_minimum_required(VERSION 3.21)
+
+    find_package(hip)
+    find_package(hipblas)
+
+    if (hipblas_FOUND)
+        message(STATUS "hipBLAS found")
+
+        set(LLAMA_HIPBLAS_PLATFORM "AMD" CACHE STRING "hip device type" FORCE)
+        set_property(CACHE LLAMA_HIPBLAS_PLATFORM PROPERTY STRINGS "AMD" "NVIDIA")
+
+        add_compile_definitions(GGML_USE_HIPBLAS "__HIP_PLATFORM_${LLAMA_HIPBLAS_PLATFORM}__")
+
+        add_library(ggml-hip OBJECT ggml-cuda.cu)
+        set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
+        target_link_libraries(ggml-hip hip::device)
+
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} hip::host roc::hipblas ggml-hip)
+
+    else()
+        message(WARNING "hipBLAS not found")
+    endif()
+endif()
+
 if (LLAMA_ALL_WARNINGS)
     if (NOT MSVC)
         set(c_flags
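Note (not part of the commit): the CMake branch above only wires up the build; it defines GGML_USE_HIPBLAS and __HIP_PLATFORM_${LLAMA_HIPBLAS_PLATFORM}__ (AMD or NVIDIA) and links hip::host plus roc::hipblas, and the sources then branch on those macros. A minimal sketch of such a compile-time branch, with a hypothetical helper name:

    /* Sketch only: illustrates code guarded by the definitions added above.
     * ggml_blas_backend_name() is hypothetical and not part of ggml. */
    #if defined(GGML_USE_HIPBLAS)
    static const char * ggml_blas_backend_name(void) { return "hipBLAS"; }
    #elif defined(GGML_USE_CUBLAS)
    static const char * ggml_blas_backend_name(void) { return "cuBLAS"; }
    #else
    static const char * ggml_blas_backend_name(void) { return "CPU / other BLAS"; }
    #endif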
Makefile (4 changes)

@@ -107,6 +107,10 @@ ifdef LLAMA_CUBLAS
 ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
 	nvcc -arch=native -c -o $@ $<
 endif
+ifdef LLAMA_HIPBLAS
+	CFLAGS  += -DGGML_USE_HIPBLAS -D__HIP_PLATFORM_AMD__ -I/opt/rocm/include
+	LDFLAGS += -lhipblas -lamdhip64 -L/opt/rocm/lib
+endif
 ifdef LLAMA_GPROF
 	CFLAGS   += -pg
 	CXXFLAGS += -pg
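A quick, hypothetical way to check the ROCm paths used above: the stand-alone C file below (not in the repository) should compile and link with the same flags, e.g. cc check_hipblas.c -D__HIP_PLATFORM_AMD__ -I/opt/rocm/include -L/opt/rocm/lib -lhipblas -lamdhip64.

    /* check_hipblas.c - hypothetical sanity check, not part of the repo.
     * Creating and destroying a hipBLAS handle confirms that the header
     * and library paths added to CFLAGS/LDFLAGS are usable. */
    #include <stdio.h>
    #include "hipblas/hipblas.h"

    int main(void) {
        hipblasHandle_t handle;
        if (hipblasCreate(&handle) != HIPBLAS_STATUS_SUCCESS) {
            fprintf(stderr, "hipblasCreate failed\n");
            return 1;
        }
        hipblasDestroy(handle);
        printf("hipBLAS link OK\n");
        return 0;
    }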
ggml-cuda.cu

@@ -1,5 +1,11 @@
 #include <stdint.h>
+#if defined(__HIP_PLATFORM_AMD__)
+#include "hip/hip_runtime.h"
+#define cudaStream_t hipStream_t
+#define __half _Float16
+#else
 #include <cuda_fp16.h>
+#endif
 #include "ggml-cuda.h"

 typedef uint16_t ggml_fp16_t;
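The two defines above are what let the existing CUDA code build for HIP: cudaStream_t becomes hipStream_t, and __half becomes the compiler's native _Float16. A hedged sketch (not from the commit) of what that enables on the AMD path, with a made-up helper:

    /* Sketch only: on the AMD HIP path __half is plain _Float16, so
     * ordinary C casts and arithmetic work on half-precision values. */
    #if defined(__HIP_PLATFORM_AMD__)
    #include "hip/hip_runtime.h"
    #define __half _Float16

    static inline float example_half_to_float(__half h) {
        return (float) h;
    }
    #endif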
ggml.c (35 changes)

@@ -147,9 +147,41 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #include <Accelerate/Accelerate.h>
 #elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
-#elif defined(GGML_USE_CUBLAS)
+#elif defined(GGML_USE_CUBLAS) || defined(GGML_USE_HIPBLAS)
+
+#if defined(GGML_USE_HIPBLAS)
+#include "hipblas/hipblas.h"
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define cublasCreate hipblasCreate
+#define cublasGemmEx hipblasGemmEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define CUDA_R_16F HIPBLAS_R_16F
+#define CUDA_R_32F HIPBLAS_R_32F
+#define cudaError_t hipError_t
+#define cudaFree hipFree
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaMalloc hipMalloc
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaStream_t hipStream_t
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaSuccess hipSuccess
+#define GGML_USE_CUBLAS
+#else
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
+#endif
 #include "ggml-cuda.h"

 #define CUDA_CHECK(err) \
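To illustrate the intent of the mapping above: code written against the cuBLAS names compiles into the hipBLAS equivalents when GGML_USE_HIPBLAS is set, and the mapping also defines GGML_USE_CUBLAS so the existing cuBLAS code paths stay enabled. A sketch only, with hypothetical variables, not code taken from ggml.c:

    /* Sketch: given the includes and #define mapping above, this
     * cuBLAS-named call becomes hipblasSgemm(handle, HIPBLAS_OP_T,
     * HIPBLAS_OP_N, ...) on a hipBLAS build. */
    #include <stdio.h>

    static int example_sgemm(cublasHandle_t handle,
                             int m, int n, int k,
                             const float * d_A,   /* k x m, device */
                             const float * d_B,   /* k x n, device */
                             float       * d_C) { /* m x n, device */
        const float alpha = 1.0f;
        const float beta  = 0.0f;
        cublasStatus_t st = cublasSgemm(handle,
                CUBLAS_OP_T, CUBLAS_OP_N,
                m, n, k,
                &alpha, d_A, k,
                        d_B, k,
                &beta,  d_C, m);
        if (st != CUBLAS_STATUS_SUCCESS) {
            fprintf(stderr, "example_sgemm: GEMM failed\n");
            return 1;
        }
        return 0;
    }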
@@ -8073,7 +8105,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
         const float * x = wdata;
 #endif
-

 #if defined(GGML_USE_CUBLAS)
         // copy data to device
         CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
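The runtime calls are covered by the same aliasing: on a hipBLAS build the cudaMemcpyAsync above resolves to hipMemcpyAsync on a hipStream_t, returning a hipError_t compared against hipSuccess. A hedged sketch of that pattern, with made-up buffer names (not code from ggml.c):

    /* Sketch only: the CUDA-named runtime calls below expand to their
     * HIP counterparts via the #define block added in ggml.c. */
    #include <stdio.h>
    #include <stdlib.h>

    static void example_upload(float * d_Y, const float * y, size_t y_ne,
                               cudaStream_t stream) {
        cudaError_t err = cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne,
                                          cudaMemcpyHostToDevice, stream);
        if (err != cudaSuccess) {
            fprintf(stderr, "memcpy failed: %s\n", cudaGetErrorString(err));
            exit(1);
        }
    }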