diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 62677981c..e2d07c202 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -2071,15 +2071,18 @@ static __global__ void mul_mat_q_stream_k_fixup( const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1; for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) { - const int64_t kb_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp); + const int64_t kbc = GGML_PAD((int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp); + const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp); - if (kb_stop % blocks_per_ne00 == 0) { + // Skip fixup tile if the MMQ CUDA block never wrote anything to it: + if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) { continue; } - const int jt = kb_stop / (blocks_per_ne00*nty); - const int it = (kb_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; + const int jt = kbc_stop / (blocks_per_ne00*nty); + const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; + // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block: if (it != blockIdx.x || jt != blockIdx.y) { continue; }