warp size fixes
This commit is contained in:
parent
33091a9bd3
commit
5d6eb72164
1 changed files with 6 additions and 2 deletions
|
@ -182,7 +182,11 @@ typedef struct {
|
|||
} block_q6_k;
|
||||
static_assert(sizeof(block_q6_k) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_k block size/padding");
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS)
|
||||
#define WARP_SIZE warpSize
|
||||
#else
|
||||
#define WARP_SIZE 32
|
||||
#endif
|
||||
|
||||
#define CUDA_MUL_BLOCK_SIZE 256
|
||||
|
||||
|
@ -679,8 +683,8 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
|
|||
// sum up partial sums and write back result
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||
for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
|
||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, WARP_SIZE);
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue