Renames the rest of the compute capability macros for consistency.
This commit is contained in:
parent
3974cf6bdf
commit
b20f068234
6 changed files with 29 additions and 29 deletions
|
@ -46,23 +46,23 @@
|
|||
#define GGML_CUDA_CC_VOLTA 700
|
||||
#define GGML_CUDA_CC_TURING 750
|
||||
#define GGML_CUDA_CC_AMPERE 800
|
||||
#define CC_OFFSET_AMD 1000000
|
||||
#define GGML_CUDA_CC_OFFSET_AMD 1000000
|
||||
|
||||
// GCN/CNDA, wave size is 64
|
||||
#define CC_GCN4 (CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16
|
||||
#define CC_VEGA (CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue
|
||||
#define CC_VEGA20 (CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a
|
||||
#define CC_CDNA (CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers
|
||||
#define CC_CDNA2 (CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing
|
||||
#define CC_CDNA3 (CC_OFFSET_AMD + 942) // MI300
|
||||
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16
|
||||
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue
|
||||
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a
|
||||
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers
|
||||
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing
|
||||
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 942) // MI300
|
||||
|
||||
// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
|
||||
#define CC_RDNA1 (CC_OFFSET_AMD + 1010) // RX 5000
|
||||
#define CC_RDNA2 (CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
|
||||
#define CC_RDNA3 (CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
|
||||
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000
|
||||
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
|
||||
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
|
||||
|
||||
#define CC_QY1 210
|
||||
#define CC_QY2 220
|
||||
#define GGML_CUDA_CC_QY1 210
|
||||
#define GGML_CUDA_CC_QY2 220
|
||||
|
||||
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
||||
|
||||
|
@ -147,20 +147,20 @@ typedef float2 dfloat2;
|
|||
#define INT8_MMA_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
|
||||
|
||||
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
|
||||
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
#define FLASH_ATTN_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
|
||||
#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
|
||||
|
||||
static constexpr bool fast_fp16_available(const int cc) {
|
||||
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
|
||||
}
|
||||
|
||||
static constexpr bool fp16_mma_available(const int cc) {
|
||||
return cc < CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
|
||||
}
|
||||
|
||||
static constexpr bool int8_mma_available(const int cc) {
|
||||
return cc < CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
|
||||
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
|
|
|
@ -304,7 +304,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (cc >= CC_OFFSET_AMD) {
|
||||
if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
|
||||
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
|
|
|
@ -177,7 +177,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|||
info.devices[id].smpb = prop.sharedMemPerBlock;
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlock;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
|
||||
#else
|
||||
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
|
||||
info.devices[id].cc = 100*prop.major + 10*prop.minor;
|
||||
|
@ -1108,7 +1108,7 @@ static void ggml_cuda_op_mul_mat_cublas(
|
|||
const half beta_f16 = 0.0f;
|
||||
|
||||
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
|
||||
if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
|
||||
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
|
||||
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||
}
|
||||
|
||||
|
@ -1612,7 +1612,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|||
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
|
||||
cudaDataType_t cu_data_type = CUDA_R_16F;
|
||||
|
||||
if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
|
||||
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
|
||||
cu_compute_type = CUBLAS_COMPUTE_32F;
|
||||
}
|
||||
|
||||
|
@ -3028,7 +3028,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
return true;
|
||||
}
|
||||
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
|
||||
return cc >= GGML_CUDA_CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
}
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
|
|
|
@ -27,7 +27,7 @@ void ggml_cuda_op_mul_mat_q(
|
|||
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
|
||||
// Also its fixup needs to allocate a temporary buffer in the memory pool.
|
||||
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
|
||||
const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < CC_OFFSET_AMD && src1_ncols == ne11;
|
||||
const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
|
||||
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
|
||||
|
||||
switch (src0->type) {
|
||||
|
@ -144,9 +144,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
|||
return true;
|
||||
#endif //GGML_CUDA_FORCE_MMQ
|
||||
|
||||
if (cc < CC_OFFSET_AMD) {
|
||||
if (cc < GGML_CUDA_CC_OFFSET_AMD) {
|
||||
return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
return (cc < CC_RDNA3 && cc != CC_CDNA && cc != CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
|
|
@ -89,9 +89,9 @@ struct tile_x_sizes {
|
|||
static constexpr int get_mmq_x_max_host(const int cc) {
|
||||
return int8_mma_available(cc) ? 128 :
|
||||
#ifdef GGML_CUDA_FORCE_MMQ
|
||||
cc >= GGML_CUDA_CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
|
||||
cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128 : 64;
|
||||
#else
|
||||
cc >= GGML_CUDA_CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
|
||||
cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
|
||||
#endif // GGML_CUDA_FORCE_MMQ
|
||||
}
|
||||
|
||||
|
@ -120,7 +120,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
|
|||
}
|
||||
|
||||
static constexpr int get_mmq_y_host(const int cc) {
|
||||
return cc >= CC_OFFSET_AMD ? (cc == CC_RDNA1 ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
|
||||
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (cc == GGML_CUDA_CC_RDNA1 ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
|
||||
}
|
||||
|
||||
static constexpr __device__ int get_mmq_y_device() {
|
||||
|
@ -2825,7 +2825,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
|||
const int mmq_x_max = get_mmq_x_max_host(cc);
|
||||
const int mmq_y = get_mmq_y_host(cc);
|
||||
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
|
||||
const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < CC_OFFSET_AMD;
|
||||
const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
|
||||
|
||||
int mmq_x_best = 0;
|
||||
int nparts_best = INT_MAX;
|
||||
|
|
|
@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
|
|||
int64_t nwarps = 1;
|
||||
int64_t rows_per_cuda_block = 1;
|
||||
|
||||
if (ggml_cuda_info().devices[id].cc < CC_CDNA || ggml_cuda_info().devices[id].cc == CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
|
||||
if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
|
||||
switch(ncols_y) {
|
||||
case 1:
|
||||
nwarps = 4;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue