Eliminate shared memory, faster summation
This commit is contained in:
parent
8a9d7ce624
commit
c46320ddf7
1 changed file with 7 additions and 10 deletions
17
ggml-cuda.cu
17
ggml-cuda.cu
|
@@ -231,8 +231,7 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
|
|||
const int row = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
__shared__ float tmp[block_size]; // separate sum for each thread
|
||||
tmp[tid] = 0;
|
||||
float partial_sum = 0; // separate sum for each thread
|
||||
|
||||
for (int i = 0; i < ncols/block_size; i += 2) {
|
||||
const int col = i*block_size + 2*tid;
|
||||
|
@@ -251,19 +250,17 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
|
|||
const float v1 = (vi1 - 8)*d;
|
||||
|
||||
// matrix multiplication
|
||||
tmp[tid] += v0 * y[col + 0];
|
||||
tmp[tid] += v1 * y[col + 1];
|
||||
partial_sum += v0 * y[col + 0];
|
||||
partial_sum += v1 * y[col + 1];
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
for (int s=block_size/2; s>0; s>>=1) {
|
||||
if (tid < s) {
|
||||
tmp[tid] += tmp[tid + s];
|
||||
}
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int mask=16; mask > 0; mask >>= 1) {
|
||||
partial_sum += __shfl_xor_sync(0xffffffff, partial_sum, mask, 32);
|
||||
}
|
||||
if (tid == 0) {
|
||||
dst[row] = tmp[0];
|
||||
dst[row] = partial_sum;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue