only a single get_mma_tile_x_k function
This commit is contained in:
parent
db6dae797b
commit
cab5981951
1 changed file with 6 additions and 9 deletions
|
@@ -124,11 +124,7 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
|
||||||
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \
|
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \
|
||||||
0
|
0
|
||||||
|
|
||||||
static int mmq_get_mma_tile_x_k_host(const ggml_type type) {
|
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
||||||
MMQ_MMA_GET_TILE_X_K_BODY;
|
|
||||||
}
|
|
||||||
|
|
||||||
static constexpr __device__ int mmq_get_mma_tile_x_k_device(ggml_type type) {
|
|
||||||
MMQ_MMA_GET_TILE_X_K_BODY;
|
MMQ_MMA_GET_TILE_X_K_BODY;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -2424,9 +2420,10 @@ struct mmq_args {
|
||||||
int64_t ne0;
|
int64_t ne0;
|
||||||
};
|
};
|
||||||
|
|
||||||
static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y, const int cc) {
|
template<ggml_type type>
|
||||||
|
static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
|
||||||
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host(type, mmq_y);
|
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host(type, mmq_y);
|
||||||
const int mmq_tile_x_k = mmq_get_mma_tile_x_k_host(type);
|
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
|
||||||
const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
|
||||||
const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
|
const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
|
||||||
return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
|
return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
|
||||||
|
@@ -2441,7 +2438,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
|
||||||
|
|
||||||
const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
|
const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
|
||||||
|
|
||||||
const int shmem = mmq_get_shmem(type, mmq_x, mmq_y, cc);
|
const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
|
||||||
|
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
|
||||||
|
@@ -2512,7 +2509,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
|
||||||
for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
|
for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
|
||||||
const int granularity = mmq_get_granularity_host(mmq_x, cc);
|
const int granularity = mmq_get_granularity_host(mmq_x, cc);
|
||||||
|
|
||||||
if (mmq_x % granularity != 0 || mmq_get_shmem(type, mmq_x, mmq_y, cc) > smpbo) {
|
if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue