CUDA => MUSA

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
This commit is contained in:
Xiaodong Ye 2024-07-16 16:31:30 +08:00
parent f085684038
commit 27396eac46
4 changed files with 234 additions and 14 deletions

View file

@ -6,6 +6,9 @@
#ifdef GGML_USE_HIPBLAS #ifdef GGML_USE_HIPBLAS
#define GGML_CUDA_NAME "ROCm" #define GGML_CUDA_NAME "ROCm"
#define GGML_CUBLAS_NAME "hipBLAS" #define GGML_CUBLAS_NAME "hipBLAS"
#elif defined(GGML_USE_MUSA)
#define GGML_CUDA_NAME "MUSA"
#define GGML_CUBLAS_NAME "muBLAS"
#else #else
#define GGML_CUDA_NAME "CUDA" #define GGML_CUDA_NAME "CUDA"
#define GGML_CUBLAS_NAME "cuBLAS" #define GGML_CUBLAS_NAME "cuBLAS"

View file

@ -19,7 +19,11 @@ typedef half2 ggml_half2;
#define GGML_COMMON_DECL #define GGML_COMMON_DECL
#elif defined(GGML_COMMON_DECL_CUDA) #elif defined(GGML_COMMON_DECL_CUDA)
#if defined(GGML_COMMON_DECL_MUSA)
#include <musa_fp16.h>
#else
#include <cuda_fp16.h> #include <cuda_fp16.h>
#endif
#include <cstdint> #include <cstdint>
typedef half ggml_half; typedef half ggml_half;
@ -415,7 +419,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
#define GGML_TABLE_END() }; #define GGML_TABLE_END() };
#define GGML_COMMON_IMPL #define GGML_COMMON_IMPL
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) #elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
#include <cstdint> #include <cstdint>
#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = { #define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {

View file

@ -167,7 +167,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
for (int id = 0; id < info.device_count; ++id) { for (int id = 0; id < info.device_count; ++id) {
int device_vmm = 0; int device_vmm = 0;
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) #if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
CUdevice device; CUdevice device;
CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGet(&device, id));
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@ -179,7 +179,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
alloc_prop.location.id = id; alloc_prop.location.id = id;
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
} }
#endif // !defined(GGML_USE_HIPBLAS) #endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
info.devices[id].vmm = !!device_vmm; info.devices[id].vmm = !!device_vmm;
cudaDeviceProp prop; cudaDeviceProp prop;
@ -315,7 +315,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
}; };
// pool with virtual memory // pool with virtual memory
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) #if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
struct ggml_cuda_pool_vmm : public ggml_cuda_pool { struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
@ -409,14 +409,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used)); GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
} }
}; };
#endif // !defined(GGML_USE_HIPBLAS) #endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) { std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) #if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
if (ggml_cuda_info().devices[device].vmm) { if (ggml_cuda_info().devices[device].vmm) {
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device)); return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
} }
#endif #endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device)); return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
} }
@ -1341,7 +1341,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
static cudaError_t ggml_cuda_Memcpy2DPeerAsync( static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
#if !defined(GGML_USE_HIPBLAS) #if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
cudaMemcpy3DPeerParms p = {}; cudaMemcpy3DPeerParms p = {};
p.dstDevice = dstDevice; p.dstDevice = dstDevice;
@ -1355,7 +1355,7 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
GGML_UNUSED(dstDevice); GGML_UNUSED(dstDevice);
GGML_UNUSED(srcDevice); GGML_UNUSED(srcDevice);
return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream); return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
#endif // !defined(GGML_USE_HIPBLAS) #endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
} }
static void ggml_cuda_op_mul_mat( static void ggml_cuda_op_mul_mat(
@ -1821,6 +1821,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
} }
} }
#else #else
#ifdef GGML_USE_MUSA
GGML_ASSERT(false);
#else // !GGML_USE_MUSA
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3 // there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx // use cublasGemmStridedBatchedEx
@ -1863,6 +1866,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
cu_compute_type, cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} }
#endif // GGML_USE_MUSA
#endif #endif
if (dst->op_params[0] == GGML_PREC_DEFAULT) { if (dst->op_params[0] == GGML_PREC_DEFAULT) {
@ -3019,7 +3023,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size
return false; return false;
} }
#if CUDART_VERSION >= 11100 #if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
if (err != cudaSuccess) { if (err != cudaSuccess) {
// clear the error // clear the error

View file

@ -12,6 +12,10 @@
#else #else
#define GGML_COMMON_DECL_CUDA #define GGML_COMMON_DECL_CUDA
#define GGML_COMMON_IMPL_CUDA #define GGML_COMMON_IMPL_CUDA
#if defined(GGML_USE_MUSA)
#define GGML_COMMON_DECL_MUSA
#define GGML_COMMON_IMPL_MUSA
#endif
#endif #endif
#include "ggml-common.h" #include "ggml-common.h"
@ -114,6 +118,151 @@
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED #define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#elif defined(GGML_USE_MUSA)
#include <musa_runtime.h>
#include <musa.h>
#include <mublas.h>
#include <musa_fp16.h>
// XXX: Keep the following order the same as hipBLAS
// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N MUBLAS_OP_N
#define CUBLAS_OP_T MUBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
// #define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F MUSA_R_16F
#define CUDA_R_32F MUSA_R_32F
// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
// #define cublasComputeType_t mublasComputeType_t
#define cublasCreate mublasCreate
#define cublasDestroy mublasDestroy
#define cublasGemmEx mublasGemmEx
#define cublasGemmBatchedEx mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
#define cublasHandle_t mublasHandle_t
// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasStatus_t mublasStatus_t
#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
#define cudaDeviceProp musaDeviceProp
#define cudaDeviceSynchronize musaDeviceSynchronize
#define cudaError_t musaError_t
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags musaEventCreateWithFlags
#define cudaEventDisableTiming musaEventDisableTiming
#define cudaEventRecord musaEventRecord
#define cudaEventSynchronize musaEventSynchronize
#define cudaEvent_t musaEvent_t
#define cudaEventDestroy musaEventDestroy
#define cudaFree musaFree
#define cudaFreeHost musaFreeHost
#define cudaGetDevice musaGetDevice
#define cudaGetDeviceCount musaGetDeviceCount
#define cudaGetDeviceProperties musaGetDeviceProperties
#define cudaGetErrorString musaGetErrorString
#define cudaGetLastError musaGetLastError
#define cudaHostRegister musaHostRegister
#define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost
#define cudaMemcpy musaMemcpy
#define cudaMemcpyAsync musaMemcpyAsync
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
#define cudaMemcpy2DAsync musaMemcpy2DAsync
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
#define cudaMemcpyKind musaMemcpyKind
#define cudaMemset musaMemset
#define cudaMemsetAsync musaMemsetAsync
#define cudaMemGetInfo musaMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
#define cudaSetDevice musaSetDevice
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
#define cudaStreamDestroy musaStreamDestroy
#define cudaStreamFireAndForget musaStreamFireAndForget
#define cudaStreamNonBlocking musaStreamNonBlocking
#define cudaStreamPerThread musaStreamPerThread
#define cudaStreamSynchronize musaStreamSynchronize
#define cudaStreamWaitEvent musaStreamWaitEvent
#define cudaStream_t musaStream_t
#define cudaSuccess musaSuccess
// XXX: Other CUDA => MUSA mapping
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
#define CUdevice MUdevice
#define CUdeviceptr MUdeviceptr
#define CUmemAccessDesc MUmemAccessDesc
#define CUmemAllocationProp MUmemAllocationProp
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
#define cuDeviceGet muDeviceGet
#define cuDeviceGetAttribute muDeviceGetAttribute
#define cuMemAddressFree muMemAddressFree
#define cuMemAddressReserve muMemAddressReserve
#define cuMemCreate muMemCreate
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
#define cuMemMap muMemMap
#define cuMemRelease muMemRelease
#define cuMemSetAccess muMemSetAccess
#define cuMemUnmap muMemUnmap
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaFuncSetAttribute musaFuncSetAttribute
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
#define make_cudaExtent make_musaExtent
#define make_cudaPitchedPtr make_musaPitchedPtr
// XXX: USE_CUDA_GRAPH
#define CUDA_SUCCESS MUSA_SUCCESS
#define CUresult MUresult
#define cuGetErrorString muGetErrorString
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
#define cudaGraphDestroy musaGraphDestroy
#define cudaGraphExecDestroy musaGraphExecDestroy
#define cudaGraphExec_t musaGraphExec_t
#define cudaGraphExecUpdate musaGraphExecUpdate
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
#define cudaGraphGetNodes musaGraphGetNodes
#define cudaGraphInstantiate musaGraphInstantiate
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
#define cudaGraphLaunch musaGraphLaunch
#define cudaGraphNodeGetType musaGraphNodeGetType
#define cudaGraphNode_t musaGraphNode_t
#define cudaGraphNodeType musaGraphNodeType
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
#define cudaGraph_t musaGraph_t
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamEndCapture musaStreamEndCapture
// XXX: cuBLAS => muBLAS mapping
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
// XXX: Clang builtins mapping
#define __vsubss4 __vsubss4_musa
#define __vsub4 __vsub4_musa
#define __vcmpeq4 __vcmpeq4_musa
#define __vcmpne4 __vcmpne4_musa
#else #else
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda.h> #include <cuda.h>
@ -168,9 +317,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
#if CUDART_VERSION >= 12000 #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
static const char * cublas_get_error_str(const cublasStatus_t err) { static const char * cublas_get_error_str(const mublasStatus_t err) {
#ifndef GGML_USE_MUSA
return cublasGetStatusString(err); return cublasGetStatusString(err);
#else
return mublasStatus_to_string(err);
#endif // GGML_USE_MUSA
} }
#else #else
static const char * cublas_get_error_str(const cublasStatus_t err) { static const char * cublas_get_error_str(const cublasStatus_t err) {
@ -200,7 +353,7 @@ static const char * cu_get_error_str(CUresult err) {
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif #endif
#if CUDART_VERSION >= 11100 #if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
#define GGML_CUDA_ASSUME(x) __builtin_assume(x) #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else #else
#define GGML_CUDA_ASSUME(x) #define GGML_CUDA_ASSUME(x)
@ -214,6 +367,62 @@ typedef float dfloat; // dequantize float
typedef float2 dfloat2; typedef float2 dfloat2;
#endif //GGML_CUDA_F16 #endif //GGML_CUDA_F16
#if defined(GGML_USE_MUSA)
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
static __device__ __forceinline__ int __vsubss4_musa(const int a, const int b) {
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
#if __has_builtin(__builtin_elementwise_sub_sat)
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
return reinterpret_cast<const int &>(c);
#else
int8x4_t c;
int16_t tmp;
#pragma unroll
for (int i = 0; i < 4; i++) {
tmp = va[i] - vb[i];
if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
c[i] = tmp;
}
return reinterpret_cast<int &>(c);
#endif // __has_builtin(__builtin_elementwise_sub_sat)
}
static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
return __vsubss4_musa(a, b);
}
static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
unsigned int c;
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
#pragma unroll
for (int i = 0; i < 4; ++i) {
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
}
return c;
}
static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
unsigned int c;
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
#pragma unroll
for (int i = 0; i < 4; ++i) {
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
}
return c;
}
#endif // defined(GGML_USE_MUSA)
#if defined(GGML_USE_HIPBLAS) #if defined(GGML_USE_HIPBLAS)
#define __CUDA_ARCH__ 1300 #define __CUDA_ARCH__ 1300
@ -455,7 +664,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
return mask_low | mask_high; return mask_low | mask_high;
} }
#endif // CUDART_VERSION < 12000 #endif // CUDART_VERSION < CUDART_HMASK
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) { static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)