Eliminate shared memory, faster summation

This commit is contained in:
JohannesGaessler 2023-05-11 20:20:41 +02:00
parent 8a9d7ce624
commit c46320ddf7

View file

@ -231,8 +231,7 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
const int row = blockIdx.x; const int row = blockIdx.x;
const int tid = threadIdx.x; const int tid = threadIdx.x;
__shared__ float tmp[block_size]; // separate sum for each thread float partial_sum = 0; // separate sum for each thread
tmp[tid] = 0;
for (int i = 0; i < ncols/block_size; i += 2) { for (int i = 0; i < ncols/block_size; i += 2) {
const int col = i*block_size + 2*tid; const int col = i*block_size + 2*tid;
@ -251,19 +250,17 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
const float v1 = (vi1 - 8)*d; const float v1 = (vi1 - 8)*d;
// matrix multiplication // matrix multiplication
tmp[tid] += v0 * y[col + 0]; partial_sum += v0 * y[col + 0];
tmp[tid] += v1 * y[col + 1]; partial_sum += v1 * y[col + 1];
} }
// sum up partial sums and write back result // sum up partial sums and write back result
for (int s=block_size/2; s>0; s>>=1) { #pragma unroll
if (tid < s) { for (int mask=16; mask > 0; mask >>= 1) {
tmp[tid] += tmp[tid + s]; partial_sum += __shfl_xor_sync(0xffffffff, partial_sum, mask, 32);
}
__syncthreads();
} }
if (tid == 0) { if (tid == 0) {
dst[row] = tmp[0]; dst[row] = partial_sum;
} }
} }