From e62fab8a49a6550f18ae8cf0e1948b80cdf05ff6 Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Thu, 18 Jul 2024 10:35:12 +0800 Subject: [PATCH] add q4_0 dequant --- ggml/src/ggml-sycl/convert.cpp | 45 ++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 39c28753c..9beefce98 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -110,18 +110,49 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k, template static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k, dpct::queue_ptr stream) { - const int nb32 = k / 32; - const int nb = (k + 255) / 256; + { dpct::has_capability_or_fail(stream->get_device(), {sycl::aspect::fp16}); +#if WARP_SIZE==16 + int constexpr WARP_K = WARP_SIZE * QK4_0; + const int n_warp = (k + WARP_K - 1) / WARP_K; + GGML_ASSERT(k % 2 == 0); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * + sycl::range<3>(1, 1, WARP_SIZE), + sycl::range<3>(1, 1, WARP_SIZE)), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { + const int warp_id = item_ct1.get_group(2); + const int lane_id = item_ct1.get_local_id(2); + const int lane_ib = warp_id * WARP_SIZE + lane_id; + if (lane_ib >= k) { + return; + } + dst_t* y_ptr = y + lane_ib * QK4_0; + + auto q_ptr = (const uint8_t*)vx + lane_ib * QK4_0 / 2; + auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib; + + const float d = float(*s_ptr); + auto qs = *(sycl::vec*)q_ptr; + + for (int l = 0; l < QK4_0 / 2; ++l) { + int vq = qs[l]; + y[l + 0] = d * ((vq & 0xF) - 8); + y[l + 16] = d * ((vq >> 4) - 8); + } + }); +#else + const int nb32 = k / 32; + const int nb = (k + 255) / 256; stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * - sycl::range<3>(1, 1, 32), - sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { - dequantize_block_q4_0(vx, y, nb32, item_ct1); - }); + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_block_q4_0(vx, y, nb32, item_ct1); + }); +#endif } }