This commit is contained in:
JohannesGaessler 2023-05-19 12:59:37 +02:00
parent 2e6cd4b025
commit fbf5588abc

View file

@ -207,8 +207,8 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
const int y_offset = qr == 1 ? 1 : qk/2; const int y_offset = qr == 1 ? 1 : qk/2;
__shared__ float tmp[block_size]; // separate sum for each thread
tmp[tid] = 0; float tmp = 0; // partial sum for thread in warp
for (int i = 0; i < ncols/block_size; i += 2) { for (int i = 0; i < ncols/block_size; i += 2) {
const int col = i*block_size + 2*tid; const int col = i*block_size + 2*tid;
@ -221,20 +221,30 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
dequantize_kernel(vx, ib, iqs, v0, v1); dequantize_kernel(vx, ib, iqs, v0, v1);
// matrix multiplication // matrix multiplication
tmp[tid] += v0 * y[iybs + iqs + 0]; tmp += v0 * y[iybs + iqs + 0];
tmp[tid] += v1 * y[iybs + iqs + y_offset]; tmp += v1 * y[iybs + iqs + y_offset];
} }
// sum up partial sums and write back result // sum up partial sums and write back result
__syncthreads(); __syncthreads();
#ifdef GGML_USE_HIPBLAS
__shared__ float tmpa[block_size];
tmpa[tid] = tmp;
for (int s=block_size/2; s>0; s>>=1) { for (int s=block_size/2; s>0; s>>=1) {
if (tid < s) { if (tid < s) {
tmp[tid] += tmp[tid + s]; tmpa[tid] += tmpa[tid + s];
} }
__syncthreads(); __syncthreads();
} }
tmp = tmpa[0]; // now full sum
#else
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
#endif
if (tid == 0) { if (tid == 0) {
dst[row] = tmp[0]; dst[row] = tmp;
} }
} }