rectified dmmv quant fix

This commit is contained in:
OuadiElfarouki 2024-09-03 19:56:54 +01:00
parent da18950038
commit 8cdbe11344

View file

@ -47,7 +47,6 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
for (int i = 0; i < ncols; i += iter_stride) { for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid; const int col = i + vals_per_iter*tid;
if (col >= ncols) break;
const int ib = (row*ncols + col)/qk; // x block index const int ib = (row*ncols + col)/qk; // x block index
const int iqs = (col%qk)/qr; // x quant index const int iqs = (col%qk)/qr; // x quant index
const int iybs = col - col%qk; // y block start index const int iybs = col - col%qk; // y block start index
@ -77,8 +76,9 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
} }
// sum up partial sums and write back result // sum up partial sums and write back result
const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { for (int mask = mask_start; mask > 0; mask >>= 1) {
tmp += tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
} }