add q4_0 dequant
This commit is contained in:
parent
0b8565d979
commit
e62fab8a49
1 changed files with 38 additions and 7 deletions
|
@ -110,18 +110,49 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
|
|||
template <typename dst_t>
|
||||
static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
|
||||
dpct::queue_ptr stream) {
|
||||
const int nb32 = k / 32;
|
||||
const int nb = (k + 255) / 256;
|
||||
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
#if WARP_SIZE==16
|
||||
int constexpr WARP_K = WARP_SIZE * QK4_0;
|
||||
const int n_warp = (k + WARP_K - 1) / WARP_K;
|
||||
GGML_ASSERT(k % 2 == 0);
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
|
||||
sycl::range<3>(1, 1, WARP_SIZE),
|
||||
sycl::range<3>(1, 1, WARP_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
const int warp_id = item_ct1.get_group(2);
|
||||
const int lane_id = item_ct1.get_local_id(2);
|
||||
const int lane_ib = warp_id * WARP_SIZE + lane_id;
|
||||
if (lane_ib >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst_t* y_ptr = y + lane_ib * QK4_0;
|
||||
|
||||
auto q_ptr = (const uint8_t*)vx + lane_ib * QK4_0 / 2;
|
||||
auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib;
|
||||
|
||||
const float d = float(*s_ptr);
|
||||
auto qs = *(sycl::vec<uint8_t, QK4_0 / 2>*)q_ptr;
|
||||
|
||||
for (int l = 0; l < QK4_0 / 2; ++l) {
|
||||
int vq = qs[l];
|
||||
y[l + 0] = d * ((vq & 0xF) - 8);
|
||||
y[l + 16] = d * ((vq >> 4) - 8);
|
||||
}
|
||||
});
|
||||
#else
|
||||
const int nb32 = k / 32;
|
||||
const int nb = (k + 255) / 256;
|
||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_q4_0(vx, y, nb32, item_ct1);
|
||||
});
|
||||
sycl::range<3>(1, 1, 32),
|
||||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
dequantize_block_q4_0(vx, y, nb32, item_ct1);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue