Add musa support

This commit is contained in:
dixyes 2024-04-15 19:05:35 +08:00
parent 5c4d767ac0
commit b7499e0460
5 changed files with 282 additions and 12 deletions

View file

@ -565,6 +565,38 @@ ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/com
endif # LLAMA_HIPBLAS
# Build the ggml CUDA backend sources with the MUSA toolchain (mcc) when
# LLAMA_MUSA is set. The same ggml-cuda.cu / ggml-cuda/*.cu sources are used;
# they are compiled as MUSA code via "-x musa".
ifdef LLAMA_MUSA
# MUSA_PATH is the prefix under which bin/mcc and lib/ are looked up.
MUSA_PATH ?= /usr/local/musa
# MUSA_ARCH is appended to "mp_" to form the --cuda-gpu-arch value below.
MUSA_ARCH ?= 10
MCC ?= $(CCACHE) $(MUSA_PATH)/bin/mcc
# Defaults for the kernel tuning knobs forwarded as -D defines further down.
LLAMA_CUDA_DMMV_X ?= 32
LLAMA_CUDA_MMV_Y ?= 1
LLAMA_CUDA_KQUANTS_ITER ?= 2
# Both GGML_USE_MUSA and GGML_USE_CUDA are defined for every translation unit.
MK_CPPFLAGS += -DGGML_USE_MUSA -DGGML_USE_CUDA
# Link against muBLAS and the MUSA driver/runtime libraries, with an rpath to
# the same directory.
MK_LDFLAGS += -L$(MUSA_PATH)/lib -Wl,-rpath=$(MUSA_PATH)/lib
MK_LDFLAGS += -lmublas -lmusa -lmusart
MUSAFLAGS += --cuda-gpu-arch=mp_$(MUSA_ARCH)
# Silence warnings from the listed diagnostic groups.
MUSAFLAGS += -Wno-unknown-warning-option -Wno-gnu-anonymous-struct -Wno-nested-anon-types -Wno-invalid-noreturn
MUSAFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
MUSAFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
MUSAFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
# Optional switches, forwarded only when the corresponding make variable is set.
ifdef LLAMA_CUDA_FORCE_DMMV
MUSAFLAGS += -DGGML_CUDA_FORCE_DMMV
endif # LLAMA_CUDA_FORCE_DMMV
ifdef LLAMA_CUDA_NO_PEER_COPY
MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY
endif # LLAMA_CUDA_NO_PEER_COPY
# One object for ggml-cuda.cu plus one per file in ggml-cuda/*.cu.
OBJS += ggml-cuda.o
OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
# ggml-cuda.o is rebuilt when any of the listed headers or any ggml-cuda/*.cuh changes.
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
$(MCC) $(CXXFLAGS) $(MUSAFLAGS) -x musa -c -o $@ $<
# Pattern rule: each ggml-cuda/NAME.cu depends on its matching NAME.cuh and common.cuh.
ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
$(MCC) $(CXXFLAGS) $(MUSAFLAGS) -x musa -c -o $@ $<
endif # LLAMA_MUSA
ifdef LLAMA_METAL
MK_CPPFLAGS += -DGGML_USE_METAL
MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit

View file

@ -17,6 +17,15 @@ typedef half2 ggml_half2;
#define GGML_COMMON_AGGR
#define GGML_COMMON_DECL
#elif defined(GGML_COMMON_DECL_MUSA)
#include <mublas.h>
typedef half ggml_half;
typedef half2 ggml_half2;
#define GGML_COMMON_AGGR
#define GGML_COMMON_DECL
#elif defined(GGML_COMMON_DECL_CUDA)
#include <cuda_fp16.h>
@ -73,7 +82,7 @@ typedef sycl::half2 ggml_half2;
#define K_SCALE_SIZE 12
#endif // GGML_QKK_64
#if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL)
#if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL) || defined(GGML_COMMON_DECL_MUSA)
// QR = QK / number of values before dequantization
// QI = number of 32 bit integers before dequantization
@ -439,7 +448,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
#define GGML_TABLE_END() };
#define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP)
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
#include <cstdint>
#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {

View file

@ -112,7 +112,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
for (int id = 0; id < info.device_count; ++id) {
int device_vmm = 0;
#if !defined(GGML_USE_HIPBLAS)
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
CUdevice device;
CU_CHECK(cuDeviceGet(&device, id));
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@ -124,7 +124,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
alloc_prop.location.id = id;
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
}
#endif // !defined(GGML_USE_HIPBLAS)
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
info.devices[id].vmm = !!device_vmm;
cudaDeviceProp prop;
@ -257,7 +257,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
};
// pool with virtual memory
#if !defined(GGML_USE_HIPBLAS)
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
@ -351,10 +351,10 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
}
};
#endif // !defined(GGML_USE_HIPBLAS)
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
#if !defined(GGML_USE_HIPBLAS)
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
if (ggml_cuda_info().devices[device].vmm) {
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
}
@ -1596,7 +1596,7 @@ static void ggml_cuda_op_mul_mat(
float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
#if !defined(GGML_USE_HIPBLAS)
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
cudaMemcpy3DPeerParms p = {};
p.dstDevice = ctx.device;
@ -1793,7 +1793,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
#if 0
#if defined(GGML_USE_MUSA)
// use cublasGemmEx
{
for (int i13 = 0; i13 < ne13; ++i13) {
@ -1802,10 +1802,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
int i02 = i12 / r2;
CUBLAS_CHECK(
cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
alpha, (const char *) src0_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
(const char *) src1_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

View file

@ -8,6 +8,9 @@
#if defined(GGML_USE_HIPBLAS)
#define GGML_COMMON_DECL_HIP
#define GGML_COMMON_IMPL_HIP
#elif defined(GGML_USE_MUSA)
#define GGML_COMMON_DECL_MUSA
#define GGML_COMMON_IMPL_MUSA
#else
#define GGML_COMMON_DECL_CUDA
#define GGML_COMMON_IMPL_CUDA
@ -117,6 +120,10 @@
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#elif defined(GGML_USE_MUSA)
#include <musa.h>
#include <mublas.h>
#include "musa_compatible.cuh"
#else
#include <cuda_runtime.h>
#include <cuda.h>
@ -189,6 +196,26 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
static const char * cublas_get_error_str(const cublasStatus_t err) {
return cublasGetStatusString(err);
}
#elif defined(GGML_USE_MUSA)
// Translate a muBLAS status code into the name of its enumerator.
// Any value without a dedicated case yields "unknown error".
static const char * cublas_get_error_str(const cublasStatus_t err) {
    const char * name;
    switch (err) {
        case MUBLAS_STATUS_SUCCESS:
            name = "MUBLAS_STATUS_SUCCESS";
            break;
        case MUBLAS_STATUS_INVALID_HANDLE:
            name = "MUBLAS_STATUS_INVALID_HANDLE";
            break;
        case MUBLAS_STATUS_NOT_IMPLEMENTED:
            name = "MUBLAS_STATUS_NOT_IMPLEMENTED";
            break;
        case MUBLAS_STATUS_INVALID_POINTER:
            name = "MUBLAS_STATUS_INVALID_POINTER";
            break;
        case MUBLAS_STATUS_INVALID_SIZE:
            name = "MUBLAS_STATUS_INVALID_SIZE";
            break;
        case MUBLAS_STATUS_MEMORY_ERROR:
            name = "MUBLAS_STATUS_MEMORY_ERROR";
            break;
        case MUBLAS_STATUS_INTERNAL_ERROR:
            name = "MUBLAS_STATUS_INTERNAL_ERROR";
            break;
        case MUBLAS_STATUS_PERF_DEGRADED:
            name = "MUBLAS_STATUS_PERF_DEGRADED";
            break;
        case MUBLAS_STATUS_SIZE_QUERY_MISMATCH:
            name = "MUBLAS_STATUS_SIZE_QUERY_MISMATCH";
            break;
        case MUBLAS_STATUS_SIZE_INCREASED:
            name = "MUBLAS_STATUS_SIZE_INCREASED";
            break;
        case MUBLAS_STATUS_SIZE_UNCHANGED:
            name = "MUBLAS_STATUS_SIZE_UNCHANGED";
            break;
        case MUBLAS_STATUS_INVALID_VALUE:
            name = "MUBLAS_STATUS_INVALID_VALUE";
            break;
        case MUBLAS_STATUS_CONTINUE:
            name = "MUBLAS_STATUS_CONTINUE";
            break;
        default:
            name = "unknown error";
            break;
    }
    return name;
}
#else
static const char * cublas_get_error_str(const cublasStatus_t err) {
switch (err) {

View file

@ -0,0 +1,202 @@
#ifndef _MUSA_COMPATIBLE_CUH
#define _MUSA_COMPATIBLE_CUH
// Type aliases: each CUDA driver / runtime / cuBLAS type name on the left is
// replaced by the MUSA / muBLAS type name on the right.
#define CUresult MUresult
#define CUdevice MUdevice
#define CUdeviceptr MUdeviceptr
#define cudaDataType_t musaDataType_t
#define cudaError_t musaError_t
#define cudaEvent_t musaEvent_t
#define cudaStream_t musaStream_t
#define cudaDeviceProp musaDeviceProp
#define cublasStatus_t mublasStatus_t
#define cublasHandle_t mublasHandle_t
// The compute type is aliased to the plain data type enum (same target as
// cudaDataType_t above).
#define cublasComputeType_t musaDataType_t // reserved in musa
// Function and flag aliases: each CUDA / cuBLAS identifier on the left is
// replaced by the MUSA / muBLAS identifier on the right.

// Driver API.
#define cuGetErrorString muGetErrorString
#define cuDeviceGet muDeviceGet
#define cuDeviceGetAttribute muDeviceGetAttribute
// Left disabled: marked by the author as not implemented so far.
// #define cuMemGetAllocationGranularity muMemGetAllocationGranularity // so far, not implemented
// #define CUmemAllocationProp MUmemAllocationProp
// Runtime API: error reporting.
#define cudaGetErrorString musaGetErrorString
#define cudaGetLastError musaGetLastError
// Runtime API: memory management.
#define cudaMemGetInfo musaMemGetInfo
#define cudaMemset musaMemset
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
#define cudaFree musaFree
#define cudaFreeHost musaFreeHost
#define cudaHostUnregister musaHostUnregister
// Runtime API: copies and copy-direction constants.
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpy2DAsync musaMemcpy2DAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
// Runtime API: device management and peer access.
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaGetDevice musaGetDevice
#define cudaGetDeviceCount musaGetDeviceCount
#define cudaGetDeviceProperties musaGetDeviceProperties
#define cudaSetDevice musaSetDevice
// Runtime API: events.
#define cudaEventRecord musaEventRecord
#define cudaEventDestroy musaEventDestroy
#define cudaEventCreate musaEventCreate
#define cudaEventSynchronize musaEventSynchronize
#define cudaEventDisableTiming musaEventDisableTiming
#define cudaEventCreateWithFlags musaEventCreateWithFlags
// Runtime API: streams.
#define cudaStreamPerThread musaStreamPerThread
#define cudaStreamSynchronize musaStreamSynchronize
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
#define cudaStreamNonBlocking musaStreamNonBlocking
#define cudaStreamDestroy musaStreamDestroy
#define cudaStreamWaitEvent musaStreamWaitEvent
// BLAS entry points that can be renamed one-to-one.
#define cublasCreate mublasCreate
#define cublasDestroy mublasDestroy
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasGemmEx mublasGemmEx
#define cublasSgemm mublasSgemm
// If mublasGemmStridedBatchedEx is already a macro at this point, remove that
// definition so the adapter below expands to a call of the plain name.
#if defined(mublasGemmStridedBatchedEx)
#undef mublasGemmStridedBatchedEx
#endif // mublasGemmStridedBatchedEx

// Adapter from the 23-argument cuBLAS-style call to the 29-argument
// mublasGemmStridedBatchedEx call: the C operand group (pointer, type, leading
// dimension, stride) is passed a second time as the D operand group, and two
// trailing zeros are appended.
// NOTE: the C-group arguments appear twice in the expansion, so an argument
// expression with side effects would be evaluated twice.
#define cublasGemmStridedBatchedEx( \
        handle, opA, opB, m, n, k, alpha, \
        A, typeA, lda, strideA, \
        B, typeB, ldb, strideB, \
        beta, \
        C, typeC, ldc, strideC, \
        batchCount, computeType, algo) \
    mublasGemmStridedBatchedEx( \
        handle, opA, opB, m, n, k, alpha, \
        A, typeA, lda, strideA, \
        B, typeB, ldb, strideB, \
        beta, \
        C, typeC, ldc, strideC, \
        C /* D */, typeC, ldc, strideC, \
        batchCount, computeType, algo, \
        0 /* solution type, reserved */, \
        0 /* flags */)
// Adapter from the 20-argument cuBLAS-style call to the 25-argument
// mublasGemmBatchedEx call: the C operand group (pointer, type, leading
// dimension) is passed a second time as the D operand group, and two trailing
// zeros are appended.
// NOTE: the C-group arguments appear twice in the expansion, so an argument
// expression with side effects would be evaluated twice.
#define cublasGemmBatchedEx( \
        handle, opA, opB, m, n, k, alpha, \
        A, typeA, lda, \
        B, typeB, ldb, \
        beta, \
        C, typeC, ldc, \
        batchCount, computeType, algo) \
    mublasGemmBatchedEx( \
        handle, opA, opB, m, n, k, alpha, \
        A, typeA, lda, \
        B, typeB, ldb, \
        beta, \
        C, typeC, ldc, \
        C /* D */, typeC, ldc, \
        batchCount, computeType, algo, \
        0 /* solution type, reserved */, \
        0 /* flags */)
// Constant aliases: version macro, driver enums, BLAS status/op/compute values
// and runtime error codes, each mapped to the MUSA / muBLAS name on the right.
#define CUDART_VERSION MUSART_VERSION
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
// Left disabled together with the allocation-granularity query they belong to.
// #define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
// #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
// Each cuBLAS status gets its own muBLAS status. Mapping both names to
// MUBLAS_STATUS_NOT_IMPLEMENTED made them alias each other (duplicate `case`
// labels in a switch) and reported the wrong failure. The pairing follows the
// hipBLAS <-> rocBLAS convention: invalid handle <-> not initialized,
// memory error <-> allocation failed.
// NOTE(review): confirm against the muBLAS headers of the targeted SDK.
#define CUBLAS_STATUS_NOT_INITIALIZED MUBLAS_STATUS_INVALID_HANDLE
#define CUBLAS_STATUS_ALLOC_FAILED MUBLAS_STATUS_MEMORY_ERROR
// NOTE(review): the original author flagged this math-mode mapping as uncertain.
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_TP32_TENSOR // ???
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_OP_N MUBLAS_OP_N
// Compute types are expressed as plain data types (cublasComputeType_t is
// aliased to musaDataType_t in this header).
#define CUBLAS_COMPUTE_16F MUSA_R_16F // reserved in musa
#define CUBLAS_COMPUTE_32F MUSA_R_32F // reserved in musa
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT_TENSOR_OP
#define CUDA_SUCCESS MUSA_SUCCESS
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_32F MUSA_R_32F
#define cudaSuccess musaSuccess
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
#endif // _MUSA_COMPATIBLE_CUH