xor hack
This commit is contained in:
parent
2e6cd4b025
commit
fbf5588abc
1 changed files with 16 additions and 6 deletions
22
ggml-cuda.cu
22
ggml-cuda.cu
|
@ -207,8 +207,8 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
|
|||
|
||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
__shared__ float tmp[block_size]; // separate sum for each thread
|
||||
tmp[tid] = 0;
|
||||
|
||||
float tmp = 0; // partial sum for thread in warp
|
||||
|
||||
for (int i = 0; i < ncols/block_size; i += 2) {
|
||||
const int col = i*block_size + 2*tid;
|
||||
|
@ -221,20 +221,30 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
|
|||
dequantize_kernel(vx, ib, iqs, v0, v1);
|
||||
|
||||
// matrix multiplication
|
||||
tmp[tid] += v0 * y[iybs + iqs + 0];
|
||||
tmp[tid] += v1 * y[iybs + iqs + y_offset];
|
||||
tmp += v0 * y[iybs + iqs + 0];
|
||||
tmp += v1 * y[iybs + iqs + y_offset];
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
__syncthreads();
|
||||
#ifdef GGML_USE_HIPBLAS
|
||||
__shared__ float tmpa[block_size];
|
||||
tmpa[tid] = tmp;
|
||||
for (int s=block_size/2; s>0; s>>=1) {
|
||||
if (tid < s) {
|
||||
tmp[tid] += tmp[tid + s];
|
||||
tmpa[tid] += tmpa[tid + s];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
tmp = tmpa[0]; // now full sum
|
||||
#else
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||
}
|
||||
#endif
|
||||
|
||||
if (tid == 0) {
|
||||
dst[row] = tmp[0];
|
||||
dst[row] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue