fix deq kernel
This commit is contained in:
parent
e62fab8a49
commit
20d3b74020
1 changed file with 4 additions and 4 deletions
|
@ -125,7 +125,7 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
|
||||||
const int warp_id = item_ct1.get_group(2);
|
const int warp_id = item_ct1.get_group(2);
|
||||||
const int lane_id = item_ct1.get_local_id(2);
|
const int lane_id = item_ct1.get_local_id(2);
|
||||||
const int lane_ib = warp_id * WARP_SIZE + lane_id;
|
const int lane_ib = warp_id * WARP_SIZE + lane_id;
|
||||||
if (lane_ib >= k) {
|
if (lane_ib >= k / Q4_0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -136,11 +136,11 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
|
||||||
const float d = float(*s_ptr);
|
const float d = float(*s_ptr);
|
||||||
auto qs = *(sycl::vec<uint8_t, QK4_0 / 2>*)q_ptr;
|
auto qs = *(sycl::vec<uint8_t, QK4_0 / 2>*)q_ptr;
|
||||||
|
#pragma unroll
|
||||||
for (int l = 0; l < QK4_0 / 2; ++l) {
|
for (int l = 0; l < QK4_0 / 2; ++l) {
|
||||||
int vq = qs[l];
|
int vq = qs[l];
|
||||||
y[l + 0] = d * ((vq & 0xF) - 8);
|
y[l * 2 + 0] = d * ((vq & 0xF) - 8);
|
||||||
y[l + 16] = d * ((vq >> 4) - 8);
|
y[l * 2 + 1] = d * ((vq >> 4) - 8);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
#else
|
#else
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue