only a single get_mma_tile_x_k function

This commit is contained in:
Johannes Gäßler 2024-06-22 15:02:42 +02:00
parent db6dae797b
commit cab5981951

View file

@ -124,11 +124,7 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \ type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \
0 0
static int mmq_get_mma_tile_x_k_host(const ggml_type type) { static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
MMQ_MMA_GET_TILE_X_K_BODY;
}
static constexpr __device__ int mmq_get_mma_tile_x_k_device(ggml_type type) {
MMQ_MMA_GET_TILE_X_K_BODY; MMQ_MMA_GET_TILE_X_K_BODY;
} }
@ -2424,9 +2420,10 @@ struct mmq_args {
int64_t ne0; int64_t ne0;
}; };
static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y, const int cc) { template<ggml_type type>
static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host(type, mmq_y); const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host(type, mmq_y);
const int mmq_tile_x_k = mmq_get_mma_tile_x_k_host(type); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const int shmem_y = mmq_x*sizeof(block_q8_1_mmq); const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
@ -2441,7 +2438,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1); const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
const int shmem = mmq_get_shmem(type, mmq_x, mmq_y, cc); const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@ -2512,7 +2509,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) { for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
const int granularity = mmq_get_granularity_host(mmq_x, cc); const int granularity = mmq_get_granularity_host(mmq_x, cc);
if (mmq_x % granularity != 0 || mmq_get_shmem(type, mmq_x, mmq_y, cc) > smpbo) { if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
continue; continue;
} }