diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 8cc0eaf73..599281cdc 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -231,8 +231,7 @@ template static __global__ void dequantize_mul_mat_q4_0(const v const int row = blockIdx.x; const int tid = threadIdx.x; - __shared__ float tmp[block_size]; // separate sum for each thread - tmp[tid] = 0; + float partial_sum = 0; // separate sum for each thread for (int i = 0; i < ncols/block_size; i += 2) { const int col = i*block_size + 2*tid; @@ -251,19 +250,17 @@ template static __global__ void dequantize_mul_mat_q4_0(const v const float v1 = (vi1 - 8)*d; // matrix multiplication - tmp[tid] += v0 * y[col + 0]; - tmp[tid] += v1 * y[col + 1]; + partial_sum += v0 * y[col + 0]; + partial_sum += v1 * y[col + 1]; } // sum up partial sums and write back result - for (int s=block_size/2; s>0; s>>=1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - __syncthreads(); +#pragma unroll + for (int mask=16; mask > 0; mask >>= 1) { + partial_sum += __shfl_xor_sync(0xffffffff, partial_sum, mask, 32); } if (tid == 0) { - dst[row] = tmp[0]; + dst[row] = partial_sum; } }