From cab59819516467bc47ae35801b9f4400ca32dea6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 22 Jun 2024 15:02:42 +0200 Subject: [PATCH] only a single get_mma_tile_x_k function --- ggml-cuda/mmq.cuh | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 6c3b2a3ee..fed9b6fb0 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -124,11 +124,7 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \ 0 -static int mmq_get_mma_tile_x_k_host(const ggml_type type) { - MMQ_MMA_GET_TILE_X_K_BODY; -} - -static constexpr __device__ int mmq_get_mma_tile_x_k_device(ggml_type type) { +static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { MMQ_MMA_GET_TILE_X_K_BODY; } @@ -2424,9 +2420,10 @@ struct mmq_args { int64_t ne0; }; -static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y, const int cc) { +template +static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) { const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host(type, mmq_y); - const int mmq_tile_x_k = mmq_get_mma_tile_x_k_host(type); + const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const int shmem_y = mmq_x*sizeof(block_q8_1_mmq); return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); @@ -2441,7 +2438,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1); - const int shmem = mmq_get_shmem(type, mmq_x, mmq_y, cc); + const int shmem = mmq_get_shmem(mmq_x, mmq_y, cc); #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -2512,7 +2509,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) { const int granularity = mmq_get_granularity_host(mmq_x, cc); - if (mmq_x % granularity != 0 || mmq_get_shmem(type, mmq_x, mmq_y, cc) > smpbo) { + if (mmq_x % granularity != 0 || mmq_get_shmem(mmq_x, mmq_y, cc) > smpbo) { continue; }