From fbf5588abc773fcfd63e2bb2647a9a1ae67a4cd8 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Fri, 19 May 2023 12:59:37 +0200 Subject: [PATCH] xor hack --- ggml-cuda.cu | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 35d2e457c..1a64ff6b1 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -207,8 +207,8 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, const int y_offset = qr == 1 ? 1 : qk/2; - __shared__ float tmp[block_size]; // separate sum for each thread - tmp[tid] = 0; + + float tmp = 0; // partial sum for thread in warp for (int i = 0; i < ncols/block_size; i += 2) { const int col = i*block_size + 2*tid; @@ -221,20 +221,30 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, dequantize_kernel(vx, ib, iqs, v0, v1); // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + y_offset]; + tmp += v0 * y[iybs + iqs + 0]; + tmp += v1 * y[iybs + iqs + y_offset]; } // sum up partial sums and write back result __syncthreads(); +#ifdef GGML_USE_HIPBLAS + __shared__ float tmpa[block_size]; + tmpa[tid] = tmp; for (int s=block_size/2; s>0; s>>=1) { if (tid < s) { - tmp[tid] += tmp[tid + s]; + tmpa[tid] += tmpa[tid + s]; } __syncthreads(); } + tmp = tmpa[0]; // now full sum +#else + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } +#endif + if (tid == 0) { - dst[row] = tmp[0]; + dst[row] = tmp; } }