add q4_0 dequant
This commit is contained in:
parent
0b8565d979
commit
e62fab8a49
1 changed file with 38 additions and 7 deletions
|
@ -110,18 +110,49 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb32 = k / 32;
|
|
||||||
const int nb = (k + 255) / 256;
|
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
#if WARP_SIZE==16
|
||||||
|
int constexpr WARP_K = WARP_SIZE * QK4_0;
|
||||||
|
const int n_warp = (k + WARP_K - 1) / WARP_K;
|
||||||
|
GGML_ASSERT(k % 2 == 0);
|
||||||
|
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
|
||||||
|
sycl::range<3>(1, 1, WARP_SIZE),
|
||||||
|
sycl::range<3>(1, 1, WARP_SIZE)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
|
const int warp_id = item_ct1.get_group(2);
|
||||||
|
const int lane_id = item_ct1.get_local_id(2);
|
||||||
|
const int lane_ib = warp_id * WARP_SIZE + lane_id;
|
||||||
|
if (lane_ib >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dst_t* y_ptr = y + lane_ib * QK4_0;
|
||||||
|
|
||||||
|
auto q_ptr = (const uint8_t*)vx + lane_ib * QK4_0 / 2;
|
||||||
|
auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k / 2) + lane_ib;
|
||||||
|
|
||||||
|
const float d = float(*s_ptr);
|
||||||
|
auto qs = *(sycl::vec<uint8_t, QK4_0 / 2>*)q_ptr;
|
||||||
|
|
||||||
|
for (int l = 0; l < QK4_0 / 2; ++l) {
|
||||||
|
int vq = qs[l];
|
||||||
|
y[l + 0] = d * ((vq & 0xF) - 8);
|
||||||
|
y[l + 16] = d * ((vq >> 4) - 8);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
#else
|
||||||
|
const int nb32 = k / 32;
|
||||||
|
const int nb = (k + 255) / 256;
|
||||||
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
||||||
sycl::range<3>(1, 1, 32),
|
sycl::range<3>(1, 1, 32),
|
||||||
sycl::range<3>(1, 1, 32)),
|
sycl::range<3>(1, 1, 32)),
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
dequantize_block_q4_0(vx, y, nb32, item_ct1);
|
dequantize_block_q4_0(vx, y, nb32, item_ct1);
|
||||||
});
|
});
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue