warp size fixes

This commit is contained in:
Henri Vasserman 2023-06-06 18:32:41 +03:00
parent 33091a9bd3
commit 5d6eb72164
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986

View file

@ -182,7 +182,11 @@ typedef struct {
} block_q6_k; } block_q6_k;
static_assert(sizeof(block_q6_k) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_k block size/padding"); static_assert(sizeof(block_q6_k) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_k block size/padding");
#if defined(GGML_USE_HIPBLAS)
#define WARP_SIZE warpSize
#else
#define WARP_SIZE 32 #define WARP_SIZE 32
#endif
#define CUDA_MUL_BLOCK_SIZE 256 #define CUDA_MUL_BLOCK_SIZE 256
@ -679,8 +683,8 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
// sum up partial sums and write back result // sum up partial sums and write back result
__syncthreads(); __syncthreads();
#pragma unroll #pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) { for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); tmp += __shfl_xor_sync(0xffffffff, tmp, mask, WARP_SIZE);
} }
if (tid == 0) { if (tid == 0) {