Eliminate shared memory, faster summation

This commit is contained in:
JohannesGaessler 2023-05-11 20:20:41 +02:00
parent 8a9d7ce624
commit c46320ddf7

View file

@ -231,8 +231,7 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
const int row = blockIdx.x;
const int tid = threadIdx.x;
__shared__ float tmp[block_size]; // separate sum for each thread
tmp[tid] = 0;
float partial_sum = 0; // separate sum for each thread
for (int i = 0; i < ncols/block_size; i += 2) {
const int col = i*block_size + 2*tid;
@ -251,19 +250,17 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
const float v1 = (vi1 - 8)*d;
// matrix multiplication
tmp[tid] += v0 * y[col + 0];
tmp[tid] += v1 * y[col + 1];
partial_sum += v0 * y[col + 0];
partial_sum += v1 * y[col + 1];
}
// sum up partial sums and write back result
for (int s=block_size/2; s>0; s>>=1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
__syncthreads();
#pragma unroll
for (int mask=16; mask > 0; mask >>= 1) {
partial_sum += __shfl_xor_sync(0xffffffff, partial_sum, mask, 32);
}
if (tid == 0) {
dst[row] = tmp[0];
dst[row] = partial_sum;
}
}