add q4_0 dequant

This commit is contained in:
luoyu-intel 2024-07-18 10:35:12 +08:00
parent 0b8565d979
commit e62fab8a49

View file

@ -110,18 +110,49 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
template <typename dst_t> template <typename dst_t>
static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k, static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
dpct::queue_ptr stream) { dpct::queue_ptr stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
{ {
dpct::has_capability_or_fail(stream->get_device(), dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16}); {sycl::aspect::fp16});
#if WARP_SIZE==16
int constexpr WARP_K = WARP_SIZE * QK4_0;
const int n_warp = (k + WARP_K - 1) / WARP_K;
GGML_ASSERT(k % 2 == 0);
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
sycl::range<3>(1, 1, WARP_SIZE),
sycl::range<3>(1, 1, WARP_SIZE)),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
const int warp_id = item_ct1.get_group(2);
const int lane_id = item_ct1.get_local_id(2);
const int lane_ib = warp_id * WARP_SIZE + lane_id;
if (lane_ib >= k) {
return;
}
dst_t* y_ptr = y + lane_ib * QK4_0;
auto q_ptr = (const uint8_t*)vx + lane_ib * QK4_0 / 2;
auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib;
const float d = float(*s_ptr);
auto qs = *(sycl::vec<uint8_t, QK4_0 / 2>*)q_ptr;
for (int l = 0; l < QK4_0 / 2; ++l) {
int vq = qs[l];
y[l + 0] = d * ((vq & 0xF) - 8);
y[l + 16] = d * ((vq >> 4) - 8);
}
});
#else
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)), sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) { [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_block_q4_0(vx, y, nb32, item_ct1); dequantize_block_q4_0(vx, y, nb32, item_ct1);
}); });
#endif
} }
} }