Add musa support
parent 5c4d767ac0
commit b7499e0460
5 changed files with 282 additions and 12 deletions
Makefile (32 changed lines)
@@ -565,6 +565,38 @@ ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/com
 endif # LLAMA_HIPBLAS
 
+ifdef LLAMA_MUSA
+MUSA_PATH ?= /usr/local/musa
+MUSA_ARCH ?= 10
+MCC ?= $(CCACHE) $(MUSA_PATH)/bin/mcc
+LLAMA_CUDA_DMMV_X ?= 32
+LLAMA_CUDA_MMV_Y ?= 1
+LLAMA_CUDA_KQUANTS_ITER ?= 2
+MK_CPPFLAGS += -DGGML_USE_MUSA -DGGML_USE_CUDA
+MK_LDFLAGS += -L$(MUSA_PATH)/lib -Wl,-rpath=$(MUSA_PATH)/lib
+MK_LDFLAGS += -lmublas -lmusa -lmusart
+MUSAFLAGS += --cuda-gpu-arch=mp_$(MUSA_ARCH)
+MUSAFLAGS += -Wno-unknown-warning-option -Wno-gnu-anonymous-struct -Wno-nested-anon-types -Wno-invalid-noreturn
+MUSAFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
+MUSAFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
+MUSAFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
+ifdef LLAMA_CUDA_FORCE_DMMV
+MUSAFLAGS += -DGGML_CUDA_FORCE_DMMV
+endif # LLAMA_CUDA_FORCE_DMMV
+ifdef LLAMA_CUDA_NO_PEER_COPY
+MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY
+endif # LLAMA_CUDA_NO_PEER_COPY
+OBJS += ggml-cuda.o
+OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
+
+ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
+	$(MCC) $(CXXFLAGS) $(MUSAFLAGS) -x musa -c -o $@ $<
+
+ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
+	$(MCC) $(CXXFLAGS) $(MUSAFLAGS) -x musa -c -o $@ $<
+
+endif # LLAMA_MUSA
+
 ifdef LLAMA_METAL
 MK_CPPFLAGS += -DGGML_USE_METAL
 MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
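The block above gates the MUSA build behind ifdef LLAMA_MUSA (so a build would presumably be enabled with make LLAMA_MUSA=1, following the pattern of the other backend switches) and defines both GGML_USE_MUSA and GGML_USE_CUDA, so the existing CUDA sources are reused and only the toolchain and the remapping headers change. A minimal, self-contained C++ sketch of what those two preprocessor flags mean for a translation unit; the program and its strings are illustrative only, not part of the commit:

// Build with: <compiler> -DGGML_USE_CUDA -DGGML_USE_MUSA sketch.cpp
#include <cstdio>

int main() {
#if defined(GGML_USE_CUDA)
    std::printf("CUDA code path enabled\n");
#endif
#if defined(GGML_USE_MUSA)
    std::printf("MUSA-specific remapping enabled\n");
#endif
    return 0;
}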
ggml-common.h
@@ -17,6 +17,15 @@ typedef half2 ggml_half2;
 #define GGML_COMMON_AGGR
 
 #define GGML_COMMON_DECL
+#elif defined(GGML_COMMON_DECL_MUSA)
+#include <mublas.h>
+
+typedef half ggml_half;
+typedef half2 ggml_half2;
+
+#define GGML_COMMON_AGGR
+
+#define GGML_COMMON_DECL
 #elif defined(GGML_COMMON_DECL_CUDA)
 #include <cuda_fp16.h>
@@ -73,7 +82,7 @@ typedef sycl::half2 ggml_half2;
 #define K_SCALE_SIZE 12
 #endif // GGML_QKK_64
 
-#if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL)
+#if defined(GGML_COMMON_DECL_CUDA) || defined(GGML_COMMON_DECL_HIP) || defined(GGML_COMMON_DECL_SYCL) || defined(GGML_COMMON_DECL_MUSA)
 // QR = QK / number of values before dequantization
 // QI = number of 32 bit integers before dequantization
@@ -439,7 +448,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
 #define GGML_TABLE_END() };
 
 #define GGML_COMMON_IMPL
-#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP)
+#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
 #include <cstdint>
 
 #define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
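The new GGML_COMMON_DECL_MUSA branch follows the existing pattern: each backend defines one GGML_COMMON_DECL_* macro before including ggml-common.h, which then supplies the matching half-precision typedefs. A self-contained C++ sketch of that selection pattern; every name suffixed _sketch is a hypothetical stand-in, and uint16_t stands in for the real half type from <mublas.h>:

#include <cstdint>
#include <cstdio>

// The backend picks its declaration flavour before "including" the shared header.
#define GGML_COMMON_DECL_MUSA_SKETCH

#if defined(GGML_COMMON_DECL_MUSA_SKETCH)
typedef uint16_t ggml_half_sketch;   // real code: half from <mublas.h>
#elif defined(GGML_COMMON_DECL_CUDA_SKETCH)
typedef uint16_t ggml_half_sketch;   // real code: half from <cuda_fp16.h>
#else
typedef float ggml_half_sketch;      // hypothetical fallback, for the sketch only
#endif

int main() {
    std::printf("sizeof(ggml_half_sketch) = %zu\n", sizeof(ggml_half_sketch));
    return 0;
}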
ggml-cuda.cu (20 changed lines)
@@ -112,7 +112,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
 
-#if !defined(GGML_USE_HIPBLAS)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
         CUdevice device;
         CU_CHECK(cuDeviceGet(&device, id));
         CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@@ -124,7 +124,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
             alloc_prop.location.id = id;
             CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
         }
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
         info.devices[id].vmm = !!device_vmm;
 
         cudaDeviceProp prop;
@@ -257,7 +257,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
 };
 
 // pool with virtual memory
-#if !defined(GGML_USE_HIPBLAS)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
 struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
     static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
 
@@ -351,10 +351,10 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
         GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
     }
 };
-#endif // !defined(GGML_USE_HIPBLAS)
+#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
 
 std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
-#if !defined(GGML_USE_HIPBLAS)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
     if (ggml_cuda_info().devices[device].vmm) {
         return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
     }
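Net effect of the added && !defined(GGML_USE_MUSA) guards: the virtual-memory pool is compiled out for MUSA builds, so new_pool_for_device always falls back to the legacy pool. A self-contained C++ sketch of that fallback; the pool types here are hypothetical stand-ins, not the real ggml_cuda_pool classes:

#include <cstdio>
#include <memory>

struct pool_sketch { virtual ~pool_sketch() = default; virtual const char * name() const = 0; };
struct pool_leg_sketch : pool_sketch { const char * name() const override { return "legacy"; } };

#if !defined(GGML_USE_MUSA)
struct pool_vmm_sketch : pool_sketch { const char * name() const override { return "vmm"; } };
#endif

static std::unique_ptr<pool_sketch> new_pool_sketch(bool device_supports_vmm) {
#if !defined(GGML_USE_MUSA)
    if (device_supports_vmm) {
        return std::make_unique<pool_vmm_sketch>();   // only reachable in non-MUSA builds
    }
#endif
    (void) device_supports_vmm;
    return std::make_unique<pool_leg_sketch>();       // MUSA builds always land here
}

int main() {
    std::printf("selected pool: %s\n", new_pool_sketch(true)->name());
    return 0;
}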
@@ -1596,7 +1596,7 @@ static void ggml_cuda_op_mul_mat(
                 float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                 GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
                 dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
-#if !defined(GGML_USE_HIPBLAS)
+#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
                 // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
                 cudaMemcpy3DPeerParms p = {};
                 p.dstDevice = ctx.device;
@@ -1793,7 +1793,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-#if 0
+#if defined(GGML_USE_MUSA)
     // use cublasGemmEx
     {
         for (int i13 = 0; i13 < ne13; ++i13) {
@@ -1802,10 +1802,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 int i02 = i12 / r2;
 
                 CUBLAS_CHECK(
-                cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
+                cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                     ne01, ne11, ne10,
-                    alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
-                           (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+                    alpha, (const char *) src0_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half),
+                           (const char *) src1_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
                     beta,  ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
                     cu_compute_type,
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
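On MUSA the batched path above replaces the single strided-batched call with one cublasGemmEx per (i12, i13) pair, broadcasting src0 batches over src1 batches through r2 = ne12/ne02 and r3 = ne13/ne03. A self-contained C++ sketch of just that index mapping; the batch sizes are made-up example values and no GPU calls are made:

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne02 = 2, ne03 = 1;   // example src0 batch dims
    const int64_t ne12 = 4, ne13 = 2;   // example src1 batch dims
    const int64_t r2 = ne12/ne02, r3 = ne13/ne03;

    for (int64_t i13 = 0; i13 < ne13; ++i13) {
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            // in the real loop, one cublasGemmEx call is issued per (i12, i13)
            std::printf("src1 batch (%lld,%lld) uses src0 batch (%lld,%lld)\n",
                        (long long) i12, (long long) i13,
                        (long long) (i12/r2), (long long) (i13/r3));
        }
    }
    return 0;
}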
ggml-cuda/common.cuh
@@ -8,6 +8,9 @@
 #if defined(GGML_USE_HIPBLAS)
 #define GGML_COMMON_DECL_HIP
 #define GGML_COMMON_IMPL_HIP
+#elif defined(GGML_USE_MUSA)
+#define GGML_COMMON_DECL_MUSA
+#define GGML_COMMON_IMPL_MUSA
 #else
 #define GGML_COMMON_DECL_CUDA
 #define GGML_COMMON_IMPL_CUDA
@@ -117,6 +120,10 @@
 #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+#elif defined(GGML_USE_MUSA)
+#include <musa.h>
+#include <mublas.h>
+#include "musa_compatible.cuh"
 #else
 #include <cuda_runtime.h>
 #include <cuda.h>
@@ -189,6 +196,26 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
 static const char * cublas_get_error_str(const cublasStatus_t err) {
     return cublasGetStatusString(err);
 }
+#elif defined(GGML_USE_MUSA)
+static const char * cublas_get_error_str(const cublasStatus_t err) {
+    switch (err) {
+        case MUBLAS_STATUS_SUCCESS: return "MUBLAS_STATUS_SUCCESS";
+        case MUBLAS_STATUS_INVALID_HANDLE: return "MUBLAS_STATUS_INVALID_HANDLE";
+        case MUBLAS_STATUS_NOT_IMPLEMENTED: return "MUBLAS_STATUS_NOT_IMPLEMENTED";
+        case MUBLAS_STATUS_INVALID_POINTER: return "MUBLAS_STATUS_INVALID_POINTER";
+        case MUBLAS_STATUS_INVALID_SIZE: return "MUBLAS_STATUS_INVALID_SIZE";
+        case MUBLAS_STATUS_MEMORY_ERROR: return "MUBLAS_STATUS_MEMORY_ERROR";
+        case MUBLAS_STATUS_INTERNAL_ERROR: return "MUBLAS_STATUS_INTERNAL_ERROR";
+        case MUBLAS_STATUS_PERF_DEGRADED: return "MUBLAS_STATUS_PERF_DEGRADED";
+        case MUBLAS_STATUS_SIZE_QUERY_MISMATCH: return "MUBLAS_STATUS_SIZE_QUERY_MISMATCH";
+        case MUBLAS_STATUS_SIZE_INCREASED: return "MUBLAS_STATUS_SIZE_INCREASED";
+        case MUBLAS_STATUS_SIZE_UNCHANGED: return "MUBLAS_STATUS_SIZE_UNCHANGED";
+        case MUBLAS_STATUS_INVALID_VALUE: return "MUBLAS_STATUS_INVALID_VALUE";
+        case MUBLAS_STATUS_CONTINUE: return "MUBLAS_STATUS_CONTINUE";
+
+        default: return "unknown error";
+    }
+}
 #else
 static const char * cublas_get_error_str(const cublasStatus_t err) {
     switch (err) {
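The added cublas_get_error_str overload maps mublas status codes to readable strings so the existing CUBLAS_CHECK-style error reporting keeps working under MUSA. A self-contained C++ sketch of how such a helper is typically consumed by a check macro; the enum, helper, and macro names here are hypothetical, not the ones defined in common.cuh:

#include <cstdio>
#include <cstdlib>

enum mublas_status_sketch { MUBLAS_SKETCH_SUCCESS = 0, MUBLAS_SKETCH_INTERNAL_ERROR = 1 };

static const char * status_str_sketch(mublas_status_sketch s) {
    switch (s) {
        case MUBLAS_SKETCH_SUCCESS:        return "MUBLAS_STATUS_SUCCESS";
        case MUBLAS_SKETCH_INTERNAL_ERROR: return "MUBLAS_STATUS_INTERNAL_ERROR";
        default:                           return "unknown error";
    }
}

#define MUBLAS_CHECK_SKETCH(expr)                               \
    do {                                                        \
        mublas_status_sketch err_ = (expr);                     \
        if (err_ != MUBLAS_SKETCH_SUCCESS) {                    \
            std::fprintf(stderr, "mublas error: %s\n",          \
                         status_str_sketch(err_));              \
            std::exit(1);                                       \
        }                                                       \
    } while (0)

int main() {
    MUBLAS_CHECK_SKETCH(MUBLAS_SKETCH_SUCCESS);   // passes silently
    std::printf("all checks passed\n");
    return 0;
}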
ggml-cuda/musa_compatible.cuh (new file, 202 lines)
@@ -0,0 +1,202 @@
+#ifndef _MUSA_COMPATIBLE_CUH
+#define _MUSA_COMPATIBLE_CUH
+
+
+#define CUresult MUresult
+#define CUdevice MUdevice
+#define CUdeviceptr MUdeviceptr
+
+#define cudaDataType_t musaDataType_t
+#define cudaError_t musaError_t
+#define cudaEvent_t musaEvent_t
+#define cudaStream_t musaStream_t
+#define cudaDeviceProp musaDeviceProp
+
+#define cublasStatus_t mublasStatus_t
+#define cublasHandle_t mublasHandle_t
+#define cublasComputeType_t musaDataType_t // reserved in musa
+
+#define cuGetErrorString muGetErrorString
+#define cuDeviceGet muDeviceGet
+#define cuDeviceGetAttribute muDeviceGetAttribute
+// #define cuMemGetAllocationGranularity muMemGetAllocationGranularity // so far, not implemented
+// #define CUmemAllocationProp MUmemAllocationProp
+
+#define cudaGetErrorString musaGetErrorString
+#define cudaGetLastError musaGetLastError
+#define cudaMemGetInfo musaMemGetInfo
+#define cudaMemset musaMemset
+#define cudaMalloc musaMalloc
+#define cudaMallocHost musaMallocHost
+#define cudaFree musaFree
+#define cudaFreeHost musaFreeHost
+#define cudaHostUnregister musaHostUnregister
+#define cudaMemcpyAsync musaMemcpyAsync
+#define cudaMemcpy2DAsync musaMemcpy2DAsync
+#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
+#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
+#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
+#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
+#define cudaDeviceSynchronize musaDeviceSynchronize
+#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
+#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
+#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
+#define cudaGetDevice musaGetDevice
+#define cudaGetDeviceCount musaGetDeviceCount
+#define cudaGetDeviceProperties musaGetDeviceProperties
+#define cudaSetDevice musaSetDevice
+#define cudaEventRecord musaEventRecord
+#define cudaEventDestroy musaEventDestroy
+#define cudaEventCreate musaEventCreate
+#define cudaEventSynchronize musaEventSynchronize
+#define cudaEventDisableTiming musaEventDisableTiming
+#define cudaEventCreateWithFlags musaEventCreateWithFlags
+#define cudaStreamPerThread musaStreamPerThread
+#define cudaStreamSynchronize musaStreamSynchronize
+#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
+#define cudaStreamNonBlocking musaStreamNonBlocking
+#define cudaStreamDestroy musaStreamDestroy
+#define cudaStreamWaitEvent musaStreamWaitEvent
+
+#define cublasCreate mublasCreate
+#define cublasDestroy mublasDestroy
+#define cublasSetMathMode mublasSetMathMode
+#define cublasSetStream mublasSetStream
+#define cublasGemmEx mublasGemmEx
+#define cublasSgemm mublasSgemm
+#ifdef mublasGemmStridedBatchedEx
+#undef mublasGemmStridedBatchedEx
+#endif // mublasGemmStridedBatchedEx
+#define cublasGemmStridedBatchedEx( \
+        handle, \
+        transA, \
+        transB, \
+        m, \
+        n, \
+        k, \
+        alpha, \
+        A, \
+        Atype, \
+        lda, \
+        strideA, \
+        B, \
+        Btype, \
+        ldb, \
+        strideB, \
+        beta, \
+        C, \
+        Ctype, \
+        ldc, \
+        strideC, \
+        batchCount, \
+        computeType, \
+        algo \
+    ) \
+    mublasGemmStridedBatchedEx( \
+        handle, \
+        transA, \
+        transB, \
+        m, \
+        n, \
+        k, \
+        alpha, \
+        A, \
+        Atype, \
+        lda, \
+        strideA, \
+        B, \
+        Btype, \
+        ldb, \
+        strideB, \
+        beta, \
+        C, \
+        Ctype, \
+        ldc, \
+        strideC, \
+        C /* D */, \
+        Ctype, \
+        ldc, \
+        strideC, \
+        batchCount, \
+        computeType, \
+        algo, \
+        0 /* solution type, reserved */, \
+        0 /* flags */ \
+    )
+
+#define cublasGemmBatchedEx( \
+        handle, \
+        transA, \
+        transB, \
+        m, \
+        n, \
+        k, \
+        alpha, \
+        A, \
+        Atype, \
+        lda, \
+        B, \
+        Btype, \
+        ldb, \
+        beta, \
+        C, \
+        Ctype, \
+        ldc, \
+        batchCount, \
+        computeType, \
+        algo \
+    ) \
+    mublasGemmBatchedEx( \
+        handle, \
+        transA, \
+        transB, \
+        m, \
+        n, \
+        k, \
+        alpha, \
+        A, \
+        Atype, \
+        lda, \
+        B, \
+        Btype, \
+        ldb, \
+        beta, \
+        C, \
+        Ctype, \
+        ldc, \
+        C /* D */, \
+        Ctype, \
+        ldc, \
+        batchCount, \
+        computeType, \
+        algo, \
+        0 /* solution type, reserved */, \
+        0 /* flags */ \
+    )
+
+#define CUDART_VERSION MUSART_VERSION
+
+#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
+// #define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
+// #define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED
+
+#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
+#define CUBLAS_STATUS_NOT_INITIALIZED MUBLAS_STATUS_NOT_IMPLEMENTED
+#define CUBLAS_STATUS_ALLOC_FAILED MUBLAS_STATUS_NOT_IMPLEMENTED
+#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_TP32_TENSOR // ???
+#define CUBLAS_OP_T MUBLAS_OP_T
+#define CUBLAS_OP_N MUBLAS_OP_N
+#define CUBLAS_COMPUTE_16F MUSA_R_16F // reserved in musa
+#define CUBLAS_COMPUTE_32F MUSA_R_32F // reserved in musa
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT_TENSOR_OP
+
+#define CUDA_SUCCESS MUSA_SUCCESS
+#define CUDA_R_16F MUSA_R_16F
+#define CUDA_R_32F MUSA_R_32F
+#define cudaSuccess musaSuccess
+#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
+
+#endif // _MUSA_COMPATIBLE_CUH
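musa_compatible.cuh keeps the CUDA spellings at every call site and lets the preprocessor redirect them to the MUSA runtime and mublas; for cublasGemmStridedBatchedEx and cublasGemmBatchedEx the wrapper macro also supplies the extra parameters the mublas entry points expect (a separate D output reusing C, plus reserved trailing arguments). A self-contained C++ sketch of that wrapping technique, using hypothetical stand-in functions rather than the real APIs:

#include <cstdio>

// Stand-in for a backend entry point that wants an explicit D output plus two
// reserved trailing arguments (loosely mirroring mublasGemmStridedBatchedEx).
static int backend_gemm_sketch(const char * A, const char * B, const char * C,
                               const char * D, int reserved, int flags) {
    std::printf("A=%s B=%s C=%s D=%s reserved=%d flags=%d\n", A, B, C, D, reserved, flags);
    return 0;
}

// CUDA-style spelling kept at the call site; the macro forwards to the backend,
// reusing C as D and defaulting the reserved arguments to 0.
#define frontend_gemm_sketch(A, B, C) \
    backend_gemm_sketch((A), (B), (C), (C) /* D */, 0 /* reserved */, 0 /* flags */)

int main() {
    return frontend_gemm_sketch("a", "b", "c");
}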