Merge 44a80b4119
into 7733f0c760
This commit is contained in:
commit
743dd102b1
1 changed files with 13 additions and 3 deletions
16
ggml-cuda.cu
16
ggml-cuda.cu
|
@ -8079,11 +8079,12 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
|
|||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
int mmq_x, mmq_y, nwarps;
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
int id;
|
||||
CUDA_CHECK(cudaGetDevice(&id));
|
||||
const int compute_capability = get_cuda_global_info().devices[id].cc;
|
||||
|
||||
int mmq_x, mmq_y, nwarps;
|
||||
if (compute_capability >= CC_RDNA2) {
|
||||
mmq_x = MMQ_X_Q4_0_RDNA2;
|
||||
mmq_y = MMQ_Y_Q4_0_RDNA2;
|
||||
|
@ -8092,17 +8093,26 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
|
|||
mmq_x = MMQ_X_Q4_0_RDNA1;
|
||||
mmq_y = MMQ_Y_Q4_0_RDNA1;
|
||||
nwarps = NWARPS_Q4_0_RDNA1;
|
||||
} else if (compute_capability >= CC_VOLTA) {
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
#else
|
||||
cudaFuncAttributes attributes;
|
||||
CUDA_CHECK(cudaFuncGetAttributes(&attributes, mul_mat_q4_0<false>));
|
||||
const int cc_binary = 10*attributes.binaryVersion;
|
||||
|
||||
if (cc_binary >= CC_VOLTA) {
|
||||
mmq_x = MMQ_X_Q4_0_AMPERE;
|
||||
mmq_y = MMQ_Y_Q4_0_AMPERE;
|
||||
nwarps = NWARPS_Q4_0_AMPERE;
|
||||
} else if (compute_capability >= MIN_CC_DP4A) {
|
||||
} else if (cc_binary >= MIN_CC_DP4A) {
|
||||
mmq_x = MMQ_X_Q4_0_PASCAL;
|
||||
mmq_y = MMQ_Y_Q4_0_PASCAL;
|
||||
nwarps = NWARPS_Q4_0_PASCAL;
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue