feat: Support Moore Threads GPU (#8383)
* Update doc for MUSA Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Add GGML_MUSA in Makefile Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Add GGML_MUSA in CMake Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * CUDA => MUSA Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * MUSA adds support for __vsubss4 Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Fix CI build failure Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> --------- Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
This commit is contained in:
parent
5e2727fe03
commit
e54c35e4fb
9 changed files with 329 additions and 30 deletions
|
@ -12,6 +12,10 @@
|
|||
#else
|
||||
#define GGML_COMMON_DECL_CUDA
|
||||
#define GGML_COMMON_IMPL_CUDA
|
||||
#if defined(GGML_USE_MUSA)
|
||||
#define GGML_COMMON_DECL_MUSA
|
||||
#define GGML_COMMON_IMPL_MUSA
|
||||
#endif
|
||||
#endif
|
||||
#include "ggml-common.h"
|
||||
|
||||
|
@ -114,6 +118,150 @@
|
|||
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
||||
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
||||
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
||||
#elif defined(GGML_USE_MUSA)
|
||||
#include <musa_runtime.h>
|
||||
#include <musa.h>
|
||||
#include <mublas.h>
|
||||
#include <musa_fp16.h>
|
||||
// XXX: Keep the following order the same as hipBLAS
|
||||
// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
|
||||
// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
|
||||
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_OP_N MUBLAS_OP_N
|
||||
#define CUBLAS_OP_T MUBLAS_OP_T
|
||||
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
||||
// #define CUBLAS_TF32_TENSOR_OP_MATH 0
|
||||
#define CUDA_R_16F MUSA_R_16F
|
||||
#define CUDA_R_32F MUSA_R_32F
|
||||
// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||
// #define cublasComputeType_t mublasComputeType_t
|
||||
#define cublasCreate mublasCreate
|
||||
#define cublasDestroy mublasDestroy
|
||||
#define cublasGemmEx mublasGemmEx
|
||||
#define cublasGemmBatchedEx mublasGemmBatchedEx
|
||||
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
|
||||
#define cublasHandle_t mublasHandle_t
|
||||
// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
||||
#define cublasSetMathMode mublasSetMathMode
|
||||
#define cublasSetStream mublasSetStream
|
||||
#define cublasSgemm mublasSgemm
|
||||
#define cublasStatus_t mublasStatus_t
|
||||
#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6
|
||||
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp musaDeviceProp
|
||||
#define cudaDeviceSynchronize musaDeviceSynchronize
|
||||
#define cudaError_t musaError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
|
||||
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
|
||||
#define cudaEventCreateWithFlags musaEventCreateWithFlags
|
||||
#define cudaEventDisableTiming musaEventDisableTiming
|
||||
#define cudaEventRecord musaEventRecord
|
||||
#define cudaEventSynchronize musaEventSynchronize
|
||||
#define cudaEvent_t musaEvent_t
|
||||
#define cudaEventDestroy musaEventDestroy
|
||||
#define cudaFree musaFree
|
||||
#define cudaFreeHost musaFreeHost
|
||||
#define cudaGetDevice musaGetDevice
|
||||
#define cudaGetDeviceCount musaGetDeviceCount
|
||||
#define cudaGetDeviceProperties musaGetDeviceProperties
|
||||
#define cudaGetErrorString musaGetErrorString
|
||||
#define cudaGetLastError musaGetLastError
|
||||
#define cudaHostRegister musaHostRegister
|
||||
#define cudaHostRegisterPortable musaHostRegisterPortable
|
||||
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
|
||||
#define cudaHostUnregister musaHostUnregister
|
||||
#define cudaLaunchHostFunc musaLaunchHostFunc
|
||||
#define cudaMalloc musaMalloc
|
||||
#define cudaMallocHost musaMallocHost
|
||||
#define cudaMemcpy musaMemcpy
|
||||
#define cudaMemcpyAsync musaMemcpyAsync
|
||||
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
||||
#define cudaMemcpy2DAsync musaMemcpy2DAsync
|
||||
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
|
||||
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
|
||||
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
|
||||
#define cudaMemcpyKind musaMemcpyKind
|
||||
#define cudaMemset musaMemset
|
||||
#define cudaMemsetAsync musaMemsetAsync
|
||||
#define cudaMemGetInfo musaMemGetInfo
|
||||
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
|
||||
#define cudaSetDevice musaSetDevice
|
||||
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
|
||||
#define cudaStreamDestroy musaStreamDestroy
|
||||
#define cudaStreamFireAndForget musaStreamFireAndForget
|
||||
#define cudaStreamNonBlocking musaStreamNonBlocking
|
||||
#define cudaStreamPerThread musaStreamPerThread
|
||||
#define cudaStreamSynchronize musaStreamSynchronize
|
||||
#define cudaStreamWaitEvent musaStreamWaitEvent
|
||||
#define cudaStream_t musaStream_t
|
||||
#define cudaSuccess musaSuccess
|
||||
|
||||
// XXX: Other CUDA => MUSA mapping
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
|
||||
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
|
||||
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
|
||||
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
|
||||
#define CUdevice MUdevice
|
||||
#define CUdeviceptr MUdeviceptr
|
||||
#define CUmemAccessDesc MUmemAccessDesc
|
||||
#define CUmemAllocationProp MUmemAllocationProp
|
||||
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
|
||||
#define cuDeviceGet muDeviceGet
|
||||
#define cuDeviceGetAttribute muDeviceGetAttribute
|
||||
#define cuMemAddressFree muMemAddressFree
|
||||
#define cuMemAddressReserve muMemAddressReserve
|
||||
#define cuMemCreate muMemCreate
|
||||
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
|
||||
#define cuMemMap muMemMap
|
||||
#define cuMemRelease muMemRelease
|
||||
#define cuMemSetAccess muMemSetAccess
|
||||
#define cuMemUnmap muMemUnmap
|
||||
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
|
||||
#define cudaFuncSetAttribute musaFuncSetAttribute
|
||||
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
|
||||
#define make_cudaExtent make_musaExtent
|
||||
#define make_cudaPitchedPtr make_musaPitchedPtr
|
||||
|
||||
// XXX: USE_CUDA_GRAPH
|
||||
#define CUDA_SUCCESS MUSA_SUCCESS
|
||||
#define CUresult MUresult
|
||||
#define cuGetErrorString muGetErrorString
|
||||
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
|
||||
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
|
||||
#define cudaGraphDestroy musaGraphDestroy
|
||||
#define cudaGraphExecDestroy musaGraphExecDestroy
|
||||
#define cudaGraphExec_t musaGraphExec_t
|
||||
#define cudaGraphExecUpdate musaGraphExecUpdate
|
||||
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
|
||||
#define cudaGraphGetNodes musaGraphGetNodes
|
||||
#define cudaGraphInstantiate musaGraphInstantiate
|
||||
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
||||
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
|
||||
#define cudaGraphLaunch musaGraphLaunch
|
||||
#define cudaGraphNodeGetType musaGraphNodeGetType
|
||||
#define cudaGraphNode_t musaGraphNode_t
|
||||
#define cudaGraphNodeType musaGraphNodeType
|
||||
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
|
||||
#define cudaGraph_t musaGraph_t
|
||||
#define cudaKernelNodeParams musaKernelNodeParams
|
||||
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
||||
#define cudaStreamEndCapture musaStreamEndCapture
|
||||
|
||||
// XXX: cuBLAS => muBLAS mapping
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
|
||||
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
||||
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
||||
#define cublasComputeType_t cudaDataType_t
|
||||
|
||||
// XXX: Clang builtins mapping
|
||||
#define __vsub4 __vsub4_musa
|
||||
#define __vcmpeq4 __vcmpeq4_musa
|
||||
#define __vcmpne4 __vcmpne4_musa
|
||||
#else
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda.h>
|
||||
|
@ -168,9 +316,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
|
|||
|
||||
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
|
||||
|
||||
#if CUDART_VERSION >= 12000
|
||||
#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
|
||||
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
||||
#ifndef GGML_USE_MUSA
|
||||
return cublasGetStatusString(err);
|
||||
#else
|
||||
return mublasStatus_to_string(err);
|
||||
#endif // GGML_USE_MUSA
|
||||
}
|
||||
#else
|
||||
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
||||
|
@ -200,7 +352,7 @@ static const char * cu_get_error_str(CUresult err) {
|
|||
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
||||
#endif
|
||||
|
||||
#if CUDART_VERSION >= 11100
|
||||
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
||||
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
||||
#else
|
||||
#define GGML_CUDA_ASSUME(x)
|
||||
|
@ -214,6 +366,42 @@ typedef float dfloat; // dequantize float
|
|||
typedef float2 dfloat2;
|
||||
#endif //GGML_CUDA_F16
|
||||
|
||||
#if defined(GGML_USE_MUSA)
|
||||
#ifndef __has_builtin
|
||||
#define __has_builtin(x) 0
|
||||
#endif
|
||||
|
||||
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
|
||||
return __vsubss4(a, b);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
|
||||
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
||||
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
||||
unsigned int c;
|
||||
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
|
||||
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
||||
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
||||
unsigned int c;
|
||||
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
#endif // defined(GGML_USE_MUSA)
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS)
|
||||
#define __CUDA_ARCH__ 1300
|
||||
|
||||
|
@ -455,7 +643,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
|
|||
const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
|
||||
return mask_low | mask_high;
|
||||
}
|
||||
#endif // CUDART_VERSION < 12000
|
||||
#endif // CUDART_VERSION < CUDART_HMASK
|
||||
|
||||
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue