sycl: fix convert overflow
Signed-off-by: zhentaoyu <zhentao.yu@intel.com>
This commit is contained in:
parent
d36d6547aa
commit
9a9f7c959c
3 changed files with 153 additions and 146 deletions
|
@ -3,19 +3,19 @@
|
||||||
#include "presets.hpp"
|
#include "presets.hpp"
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||||
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
|
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||||
item_ct1.get_local_id(2));
|
item_ct1.get_local_id(2));
|
||||||
|
|
||||||
if (i >= k) {
|
if (i >= k) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int ib = i/qk; // block index
|
const int64_t ib = i/qk; // block index
|
||||||
const int iqs = (i%qk)/qr; // quant index
|
const int64_t iqs = (i%qk)/qr; // quant index
|
||||||
const int iybs = i - i%qk; // y block start index
|
const int64_t iybs = i - i%qk; // y block start index
|
||||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
const int64_t y_offset = qr == 1 ? 1 : qk/2;
|
||||||
|
|
||||||
// dequantize
|
// dequantize
|
||||||
dfloat2 v;
|
dfloat2 v;
|
||||||
|
@ -27,7 +27,7 @@ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||||
static void dequantize_block_sycl(const void *__restrict__ vx,
|
static void dequantize_block_sycl(const void *__restrict__ vx,
|
||||||
dst_t *__restrict__ y, const int k,
|
dst_t *__restrict__ y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
|
const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
|
||||||
{
|
{
|
||||||
|
@ -45,9 +45,9 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = k / QK_K;
|
const int64_t nb = k / QK_K;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
@ -77,9 +77,9 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = k / QK_K;
|
const int64_t nb = k / QK_K;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
@ -108,10 +108,10 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb32 = k / 32;
|
const int64_t nb32 = k / 32;
|
||||||
const int nb = (k + 255) / 256;
|
const int64_t nb = (k + 255) / 256;
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
@ -126,10 +126,10 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb32 = k / 32;
|
const int64_t nb32 = k / 32;
|
||||||
const int nb = (k + 255) / 256;
|
const int64_t nb = (k + 255) / 256;
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
@ -145,9 +145,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = k / QK_K;
|
const int64_t nb = k / QK_K;
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
@ -165,9 +165,9 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = k / QK_K;
|
const int64_t nb = k / QK_K;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
@ -197,9 +197,9 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = k / QK_K;
|
const int64_t nb = k / QK_K;
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
|
@ -229,9 +229,9 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = k / QK_K;
|
const int64_t nb = k / QK_K;
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
@ -250,9 +250,9 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = k / QK_K;
|
const int64_t nb = k / QK_K;
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
@ -271,9 +271,9 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = k / QK_K;
|
const int64_t nb = k / QK_K;
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
@ -292,9 +292,9 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = k / QK_K;
|
const int64_t nb = k / QK_K;
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
@ -313,9 +313,9 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = k / QK_K;
|
const int64_t nb = k / QK_K;
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
@ -333,9 +333,9 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
|
||||||
|
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = k / QK_K;
|
const int64_t nb = k / QK_K;
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
@ -354,9 +354,9 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = k / QK_K;
|
const int64_t nb = k / QK_K;
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
@ -374,9 +374,9 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = (k + QK_K - 1) / QK_K;
|
const int64_t nb = (k + QK_K - 1) / QK_K;
|
||||||
#if QK_K == 64
|
#if QK_K == 64
|
||||||
dequantize_row_iq4_nl_sycl(vx, y, k, stream);
|
dequantize_row_iq4_nl_sycl(vx, y, k, stream);
|
||||||
#else
|
#else
|
||||||
|
@ -398,9 +398,9 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
|
static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int nb = (k + QK_K - 1) / QK_K;
|
const int64_t nb = (k + QK_K - 1) / QK_K;
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
@ -418,34 +418,41 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename src_t, typename dst_t>
|
template <typename src_t, typename dst_t>
|
||||||
static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
|
static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
const int64_t work_group_size = item_ct1.get_local_range(2);
|
||||||
item_ct1.get_local_id(2);
|
const int64_t global_id = item_ct1.get_local_id(2) + item_ct1.get_group(2) * work_group_size;
|
||||||
|
|
||||||
if (i >= k) {
|
// make each work-item deal with more elements since sycl global range can not exceed max int
|
||||||
return;
|
for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
|
||||||
|
const src_t * x = (src_t *) vx;
|
||||||
|
|
||||||
|
y[i] = x[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
const src_t * x = (src_t *) vx;
|
|
||||||
|
|
||||||
y[i] = x[i];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename src_t, typename dst_t>
|
template <typename src_t, typename dst_t>
|
||||||
static void convert_unary_sycl(const void *__restrict__ vx,
|
static void convert_unary_sycl(const void *__restrict__ vx,
|
||||||
dst_t *__restrict__ y, const int k,
|
dst_t *__restrict__ y, const int64_t k,
|
||||||
dpct::queue_ptr stream) {
|
dpct::queue_ptr stream) {
|
||||||
const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
|
const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
|
||||||
|
|
||||||
|
// decrease global range when it exceeds the max int
|
||||||
|
int local_size = SYCL_DEQUANTIZE_BLOCK_SIZE;
|
||||||
|
const int64_t max_range = std::numeric_limits<int>::max();
|
||||||
|
int64_t global_range = num_blocks * local_size;
|
||||||
|
while(global_range > max_range) {
|
||||||
|
local_size /= 2;
|
||||||
|
global_range = num_blocks * local_size;
|
||||||
|
}
|
||||||
|
sycl::range<3> block_nums(1, 1, num_blocks);
|
||||||
|
sycl::range<3> local_range(1, 1, local_size);
|
||||||
{
|
{
|
||||||
dpct::has_capability_or_fail(stream->get_device(),
|
dpct::has_capability_or_fail(stream->get_device(),
|
||||||
{sycl::aspect::fp16});
|
{sycl::aspect::fp16});
|
||||||
|
|
||||||
stream->parallel_for(
|
stream->parallel_for(
|
||||||
sycl::nd_range<3>(
|
sycl::nd_range<3>(block_nums * local_range, local_range),
|
||||||
sycl::range<3>(1, 1, num_blocks) *
|
|
||||||
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
|
|
||||||
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
|
|
||||||
[=](sycl::nd_item<3> item_ct1) {
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
convert_unary<src_t>(vx, y, k, item_ct1);
|
convert_unary<src_t>(vx, y, k, item_ct1);
|
||||||
});
|
});
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
|
using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
|
||||||
int k, dpct::queue_ptr stream);
|
int64_t k, dpct::queue_ptr stream);
|
||||||
typedef to_t_sycl_t<float> to_fp32_sycl_t;
|
typedef to_t_sycl_t<float> to_fp32_sycl_t;
|
||||||
typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
|
typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
|
||||||
|
|
||||||
|
|
|
@ -141,13 +141,13 @@ template<typename dst_t>
|
||||||
static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
|
static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
|
|
||||||
// assume 32 threads
|
// assume 32 threads
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
const int il = tid/8;
|
const int64_t il = tid/8;
|
||||||
const int ir = tid%8;
|
const int64_t ir = tid%8;
|
||||||
const int ib = 8*i + ir;
|
const int64_t ib = 8*i + ir;
|
||||||
if (ib >= nb32) {
|
if (ib >= nb32) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -171,13 +171,13 @@ template<typename dst_t>
|
||||||
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
|
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
|
|
||||||
// assume 32 threads
|
// assume 32 threads
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
const int il = tid/8;
|
const int64_t il = tid/8;
|
||||||
const int ir = tid%8;
|
const int64_t ir = tid%8;
|
||||||
const int ib = 8*i + ir;
|
const int64_t ib = 8*i + ir;
|
||||||
if (ib >= nb32) {
|
if (ib >= nb32) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -203,14 +203,14 @@ template<typename dst_t>
|
||||||
static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
const block_q2_K * x = (const block_q2_K *) vx;
|
const block_q2_K * x = (const block_q2_K *) vx;
|
||||||
|
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int n = tid/32;
|
const int64_t n = tid/32;
|
||||||
const int l = tid - 32*n;
|
const int64_t l = tid - 32*n;
|
||||||
const int is = 8*n + l/16;
|
const int64_t is = 8*n + l/16;
|
||||||
|
|
||||||
const uint8_t q = x[i].qs[32*n + l];
|
const uint8_t q = x[i].qs[32*n + l];
|
||||||
dst_t * y = yy + i*QK_K + 128*n;
|
dst_t * y = yy + i*QK_K + 128*n;
|
||||||
|
@ -222,8 +222,8 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri
|
||||||
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
|
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
|
||||||
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
|
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
|
||||||
#else
|
#else
|
||||||
const int is = tid/16; // 0 or 1
|
const int64_t is = tid/16; // 0 or 1
|
||||||
const int il = tid%16; // 0...15
|
const int64_t il = tid%16; // 0...15
|
||||||
const uint8_t q = x[i].qs[il] >> (2*is);
|
const uint8_t q = x[i].qs[il] >> (2*is);
|
||||||
dst_t * y = yy + i*QK_K + 16*is + il;
|
dst_t * y = yy + i*QK_K + 16*is + il;
|
||||||
|
|
||||||
|
@ -239,19 +239,19 @@ template<typename dst_t>
|
||||||
static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
const block_q3_K * x = (const block_q3_K *) vx;
|
const block_q3_K * x = (const block_q3_K *) vx;
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int r = item_ct1.get_local_id(2) / 4;
|
const int64_t r = item_ct1.get_local_id(2) / 4;
|
||||||
const int tid = r/2;
|
const int64_t tid = r/2;
|
||||||
const int is0 = r%2;
|
const int64_t is0 = r%2;
|
||||||
const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
|
const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
|
||||||
const int n = tid / 4;
|
const int64_t n = tid / 4;
|
||||||
const int j = tid - 4*n;
|
const int64_t j = tid - 4*n;
|
||||||
|
|
||||||
uint8_t m = 1 << (4*n + j);
|
uint8_t m = 1 << (4*n + j);
|
||||||
int is = 8*n + 2*j + is0;
|
int64_t is = 8*n + 2*j + is0;
|
||||||
int shift = 2*j;
|
int shift = 2*j;
|
||||||
|
|
||||||
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
|
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
|
||||||
|
@ -267,11 +267,11 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
|
||||||
|
|
||||||
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
|
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
|
||||||
#else
|
#else
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
const int is = tid/16; // 0 or 1
|
const int64_t is = tid/16; // 0 or 1
|
||||||
const int il = tid%16; // 0...15
|
const int64_t il = tid%16; // 0...15
|
||||||
const int im = il/8; // 0...1
|
const int64_t im = il/8; // 0...1
|
||||||
const int in = il%8; // 0...7
|
const int64_t in = il%8; // 0...7
|
||||||
|
|
||||||
dst_t * y = yy + i*QK_K + 16*is + il;
|
dst_t * y = yy + i*QK_K + 16*is + il;
|
||||||
|
|
||||||
|
@ -307,15 +307,15 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
|
||||||
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
|
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
|
||||||
const block_q4_K * x = (const block_q4_K *) vx;
|
const block_q4_K * x = (const block_q4_K *) vx;
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
// assume 32 threads
|
// assume 32 threads
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
const int il = tid/8;
|
const int64_t il = tid/8;
|
||||||
const int ir = tid%8;
|
const int64_t ir = tid%8;
|
||||||
const int is = 2*il;
|
const int64_t is = 2*il;
|
||||||
const int n = 4;
|
const int64_t n = 4;
|
||||||
|
|
||||||
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
dst_t * y = yy + i*QK_K + 64*il + n*ir;
|
||||||
|
|
||||||
|
@ -341,7 +341,7 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
|
||||||
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
|
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
const uint8_t * q = x[i].qs;
|
const uint8_t * q = x[i].qs;
|
||||||
dst_t * y = yy + i*QK_K;
|
dst_t * y = yy + i*QK_K;
|
||||||
const float d = (float)x[i].dm[0];
|
const float d = (float)x[i].dm[0];
|
||||||
|
@ -356,14 +356,14 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
const block_q5_K * x = (const block_q5_K *) vx;
|
const block_q5_K * x = (const block_q5_K *) vx;
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
// assume 64 threads - this is very slightly better than the one below
|
// assume 64 threads - this is very slightly better than the one below
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
const int il = tid/16; // il is in 0...3
|
const int64_t il = tid/16; // il is in 0...3
|
||||||
const int ir = tid%16; // ir is in 0...15
|
const int64_t ir = tid%16; // ir is in 0...15
|
||||||
const int is = 2*il; // is is in 0...6
|
const int64_t is = 2*il; // is is in 0...6
|
||||||
|
|
||||||
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
|
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
|
||||||
|
|
||||||
|
@ -386,11 +386,11 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
|
||||||
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
|
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
|
||||||
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
|
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
|
||||||
#else
|
#else
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
const uint8_t q = x[i].qs[tid];
|
const uint8_t q = x[i].qs[tid];
|
||||||
const int im = tid/8; // 0...3
|
const int64_t im = tid/8; // 0...3
|
||||||
const int in = tid%8; // 0...7
|
const int64_t in = tid%8; // 0...7
|
||||||
const int is = tid/16; // 0 or 1
|
const int64_t is = tid/16; // 0 or 1
|
||||||
const uint8_t h = x[i].qh[in] >> im;
|
const uint8_t h = x[i].qh[in] >> im;
|
||||||
const float d = x[i].d;
|
const float d = x[i].d;
|
||||||
dst_t * y = yy + i*QK_K + tid;
|
dst_t * y = yy + i*QK_K + tid;
|
||||||
|
@ -404,14 +404,14 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
const block_q6_K * x = (const block_q6_K *) vx;
|
const block_q6_K * x = (const block_q6_K *) vx;
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
|
|
||||||
// assume 64 threads - this is very slightly better than the one below
|
// assume 64 threads - this is very slightly better than the one below
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
const int ip = tid/32; // ip is 0 or 1
|
const int64_t ip = tid/32; // ip is 0 or 1
|
||||||
const int il = tid - 32*ip; // 0...32
|
const int64_t il = tid - 32*ip; // 0...32
|
||||||
const int is = 8*ip + il/16;
|
const int64_t is = 8*ip + il/16;
|
||||||
|
|
||||||
dst_t * y = yy + i*QK_K + 128*ip + il;
|
dst_t * y = yy + i*QK_K + 128*ip + il;
|
||||||
|
|
||||||
|
@ -428,9 +428,9 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
|
||||||
#else
|
#else
|
||||||
|
|
||||||
// assume 32 threads
|
// assume 32 threads
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
const int ip = tid/16; // 0 or 1
|
const int64_t ip = tid/16; // 0 or 1
|
||||||
const int il = tid - 16*ip; // 0...15
|
const int64_t il = tid - 16*ip; // 0...15
|
||||||
|
|
||||||
dst_t * y = yy + i*QK_K + 16*ip + il;
|
dst_t * y = yy + i*QK_K + 16*ip + il;
|
||||||
|
|
||||||
|
@ -452,13 +452,13 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __res
|
||||||
const uint8_t *ksigns_iq2xs_ptr,
|
const uint8_t *ksigns_iq2xs_ptr,
|
||||||
const uint8_t *kmask_iq2xs_ptr) {
|
const uint8_t *kmask_iq2xs_ptr) {
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
|
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
|
||||||
|
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint16_t * q2 = x[i].qs + 4*ib;
|
const uint16_t * q2 = x[i].qs + 4*ib;
|
||||||
const uint8_t * aux8 = (const uint8_t *)q2;
|
const uint8_t * aux8 = (const uint8_t *)q2;
|
||||||
|
@ -480,13 +480,13 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest
|
||||||
const uint8_t *ksigns_iq2xs,
|
const uint8_t *ksigns_iq2xs,
|
||||||
const uint8_t *kmask_iq2xs) {
|
const uint8_t *kmask_iq2xs) {
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
const block_iq2_xs * x = (const block_iq2_xs *) vx;
|
const block_iq2_xs * x = (const block_iq2_xs *) vx;
|
||||||
|
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint16_t * q2 = x[i].qs + 4*ib;
|
const uint16_t * q2 = x[i].qs + 4*ib;
|
||||||
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
|
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
|
||||||
|
@ -504,13 +504,13 @@ __dpct_inline__ static void
|
||||||
dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
const block_iq2_s * x = (const block_iq2_s *) vx;
|
const block_iq2_s * x = (const block_iq2_s *) vx;
|
||||||
|
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
|
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
|
||||||
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
|
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
|
||||||
|
@ -532,13 +532,13 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
|
||||||
const uint8_t *ksigns_iq2xs,
|
const uint8_t *ksigns_iq2xs,
|
||||||
const uint8_t *kmask_iq2xs) {
|
const uint8_t *kmask_iq2xs) {
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
|
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
|
||||||
|
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint8_t * q3 = x[i].qs + 8*ib;
|
const uint8_t * q3 = x[i].qs + 8*ib;
|
||||||
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
|
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
|
||||||
|
@ -563,13 +563,13 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
||||||
const sycl::nd_item<3> &item_ct1,
|
const sycl::nd_item<3> &item_ct1,
|
||||||
const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
|
const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
const block_iq3_s * x = (const block_iq3_s *) vx;
|
const block_iq3_s * x = (const block_iq3_s *) vx;
|
||||||
|
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint8_t * qs = x[i].qs + 8*ib;
|
const uint8_t * qs = x[i].qs + 8*ib;
|
||||||
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
|
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
|
||||||
|
@ -593,13 +593,13 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
||||||
const sycl::nd_item<3> &item_ct1,
|
const sycl::nd_item<3> &item_ct1,
|
||||||
const uint32_t *iq1s_grid_gpu) {
|
const uint32_t *iq1s_grid_gpu) {
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
const block_iq1_s * x = (const block_iq1_s *) vx;
|
const block_iq1_s * x = (const block_iq1_s *) vx;
|
||||||
|
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
|
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
|
||||||
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
|
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
|
||||||
|
@ -623,13 +623,13 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
||||||
const sycl::nd_item<3> &item_ct1,
|
const sycl::nd_item<3> &item_ct1,
|
||||||
const uint32_t *iq1s_grid_gpu) {
|
const uint32_t *iq1s_grid_gpu) {
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
const block_iq1_m * x = (const block_iq1_m *) vx;
|
const block_iq1_m * x = (const block_iq1_m *) vx;
|
||||||
|
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
||||||
iq1m_scale_t scale;
|
iq1m_scale_t scale;
|
||||||
|
@ -656,12 +656,12 @@ __dpct_inline__ static void
|
||||||
dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
|
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
|
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
|
||||||
|
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
||||||
const uint8_t * q4 = x[ib].qs + 4*il;
|
const uint8_t * q4 = x[ib].qs + 4*il;
|
||||||
const float d = (float)x[ib].d;
|
const float d = (float)x[ib].d;
|
||||||
|
@ -678,12 +678,12 @@ template <typename dst_t>
|
||||||
__dpct_inline__ static void
|
__dpct_inline__ static void
|
||||||
dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
|
||||||
const sycl::nd_item<3> &item_ct1) {
|
const sycl::nd_item<3> &item_ct1) {
|
||||||
const int i = item_ct1.get_group(2);
|
const int64_t i = item_ct1.get_group(2);
|
||||||
const block_iq4_xs * x = (const block_iq4_xs *)vx;
|
const block_iq4_xs * x = (const block_iq4_xs *)vx;
|
||||||
|
|
||||||
const int tid = item_ct1.get_local_id(2);
|
const int64_t tid = item_ct1.get_local_id(2);
|
||||||
const int il = tid/8; // 0...3
|
const int64_t il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int64_t ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
||||||
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
|
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
|
||||||
const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
|
const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue