CUDA: switch tile sizes based on binary version
This commit is contained in:
parent
d250c9d61d
commit
44a80b4119
1 changed file with 13 additions and 3 deletions
16
ggml-cuda.cu
16
ggml-cuda.cu
|
@@ -6941,11 +6941,12 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||||
|
|
||||||
|
int mmq_x, mmq_y, nwarps;
|
||||||
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
int id;
|
int id;
|
||||||
CUDA_CHECK(cudaGetDevice(&id));
|
CUDA_CHECK(cudaGetDevice(&id));
|
||||||
const int compute_capability = g_device_caps[id].cc;
|
const int compute_capability = g_device_caps[id].cc;
|
||||||
|
|
||||||
int mmq_x, mmq_y, nwarps;
|
|
||||||
if (compute_capability >= CC_RDNA2) {
|
if (compute_capability >= CC_RDNA2) {
|
||||||
mmq_x = MMQ_X_Q4_0_RDNA2;
|
mmq_x = MMQ_X_Q4_0_RDNA2;
|
||||||
mmq_y = MMQ_Y_Q4_0_RDNA2;
|
mmq_y = MMQ_Y_Q4_0_RDNA2;
|
||||||
|
@@ -6954,17 +6955,26 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||||
mmq_x = MMQ_X_Q4_0_RDNA1;
|
mmq_x = MMQ_X_Q4_0_RDNA1;
|
||||||
mmq_y = MMQ_Y_Q4_0_RDNA1;
|
mmq_y = MMQ_Y_Q4_0_RDNA1;
|
||||||
nwarps = NWARPS_Q4_0_RDNA1;
|
nwarps = NWARPS_Q4_0_RDNA1;
|
||||||
} else if (compute_capability >= CC_VOLTA) {
|
} else {
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
cudaFuncAttributes attributes;
|
||||||
|
CUDA_CHECK(cudaFuncGetAttributes(&attributes, mul_mat_q4_0<false>));
|
||||||
|
const int cc_binary = 10*attributes.binaryVersion;
|
||||||
|
|
||||||
|
if (cc_binary >= CC_VOLTA) {
|
||||||
mmq_x = MMQ_X_Q4_0_AMPERE;
|
mmq_x = MMQ_X_Q4_0_AMPERE;
|
||||||
mmq_y = MMQ_Y_Q4_0_AMPERE;
|
mmq_y = MMQ_Y_Q4_0_AMPERE;
|
||||||
nwarps = NWARPS_Q4_0_AMPERE;
|
nwarps = NWARPS_Q4_0_AMPERE;
|
||||||
} else if (compute_capability >= MIN_CC_DP4A) {
|
} else if (cc_binary >= MIN_CC_DP4A) {
|
||||||
mmq_x = MMQ_X_Q4_0_PASCAL;
|
mmq_x = MMQ_X_Q4_0_PASCAL;
|
||||||
mmq_y = MMQ_Y_Q4_0_PASCAL;
|
mmq_y = MMQ_Y_Q4_0_PASCAL;
|
||||||
nwarps = NWARPS_Q4_0_PASCAL;
|
nwarps = NWARPS_Q4_0_PASCAL;
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
|
||||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||||
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue