rectified dmmv quant fix
This commit is contained in:
parent
da18950038
commit
8cdbe11344
1 changed files with 2 additions and 2 deletions
|
@ -47,7 +47,6 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
|
||||||
|
|
||||||
for (int i = 0; i < ncols; i += iter_stride) {
|
for (int i = 0; i < ncols; i += iter_stride) {
|
||||||
const int col = i + vals_per_iter*tid;
|
const int col = i + vals_per_iter*tid;
|
||||||
if (col >= ncols) break;
|
|
||||||
const int ib = (row*ncols + col)/qk; // x block index
|
const int ib = (row*ncols + col)/qk; // x block index
|
||||||
const int iqs = (col%qk)/qr; // x quant index
|
const int iqs = (col%qk)/qr; // x quant index
|
||||||
const int iybs = col - col%qk; // y block start index
|
const int iybs = col - col%qk; // y block start index
|
||||||
|
@ -77,8 +76,9 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
|
const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
|
for (int mask = mask_start; mask > 0; mask >>= 1) {
|
||||||
tmp +=
|
tmp +=
|
||||||
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue