warp size fixes
This commit is contained in:
parent
33091a9bd3
commit
5d6eb72164
1 changed files with 6 additions and 2 deletions
|
@ -182,7 +182,11 @@ typedef struct {
|
||||||
} block_q6_k;
|
} block_q6_k;
|
||||||
static_assert(sizeof(block_q6_k) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_k block size/padding");
|
static_assert(sizeof(block_q6_k) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_k block size/padding");
|
||||||
|
|
||||||
|
#if defined(GGML_USE_HIPBLAS)
|
||||||
|
#define WARP_SIZE warpSize
|
||||||
|
#else
|
||||||
#define WARP_SIZE 32
|
#define WARP_SIZE 32
|
||||||
|
#endif
|
||||||
|
|
||||||
#define CUDA_MUL_BLOCK_SIZE 256
|
#define CUDA_MUL_BLOCK_SIZE 256
|
||||||
|
|
||||||
|
@ -679,8 +683,8 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, WARP_SIZE);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue