fix build break by iq1s
This commit is contained in:
parent
5cdb371731
commit
59f1f6aefc
1 changed files with 15 additions and 15 deletions
|
@ -4891,7 +4891,7 @@ static void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restr
|
|||
template<typename dst_t>
|
||||
static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy,
|
||||
const sycl::nd_item<3> &item_ct1,
|
||||
const uint64_t *iq1s_grid,
|
||||
const uint32_t *iq1s_grid_gpu,
|
||||
const uint8_t *ksigns_iq2xs,
|
||||
const uint8_t *kmask_iq2xs) {
|
||||
|
||||
|
@ -4905,7 +4905,7 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
|
|||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
const int i8 = 4*ib+il;
|
||||
uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
|
||||
const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
|
||||
const int8_t * grid = (const int8_t *)(iq1s_grid_gpu + (x[i].qs[i8] | ((h & 8) << 5)));
|
||||
const float d = (float)x[i].d * (2*(h & 7) + 1);
|
||||
for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
|
||||
#else
|
||||
|
@ -7803,7 +7803,7 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
|
|||
static __dpct_inline__ float
|
||||
vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
|
||||
const block_q8_1 *__restrict__ bq8_1, const int &iqs,
|
||||
const uint64_t *iq1s_grid, const uint64_t *ksigns64) {
|
||||
const uint32_t *iq1s_grid_gpu, const uint64_t *ksigns64) {
|
||||
#if QK_K == 256
|
||||
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
|
||||
|
||||
|
@ -7812,10 +7812,10 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
|
|||
const uint8_t h1 = bq1->scales[2*ib32+0];
|
||||
const uint8_t h2 = bq1->scales[2*ib32+1];
|
||||
const int * q8 = (const int *)bq8_1[ib32].qs;
|
||||
const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
|
||||
const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
|
||||
const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
|
||||
const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
|
||||
const int * grid1 = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
|
||||
const int * grid2 = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
|
||||
const int * grid3 = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
|
||||
const int * grid4 = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
|
||||
for (int j = 0; j < 2; ++j) {
|
||||
sumi1 = dpct::dp4a(q8[j+0], grid1[j], sumi1);
|
||||
sumi2 = dpct::dp4a(q8[j+2], grid2[j], sumi2);
|
||||
|
@ -8644,7 +8644,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void * __restrict__ vx, const void *
|
|||
template <int qk, int qi, typename block_q_t, int vdr>
|
||||
static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
|
||||
const sycl::nd_item<3> &item_ct1,
|
||||
const uint64_t *iq1s_grid_ptr, const uint64_t *ksigns64_ptr ) {
|
||||
const uint32_t *iq1s_grid_gpu_ptr, const uint64_t *ksigns64_ptr ) {
|
||||
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
||||
item_ct1.get_local_id(1);
|
||||
|
||||
|
@ -8672,7 +8672,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void * __restrict__ vx, const void *
|
|||
(item_ct1.get_local_id(2) %
|
||||
(qi / vdr)); // x block quant index when casting the quants to int
|
||||
|
||||
tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_ptr, ksigns64_ptr);
|
||||
tmp += vec_dot_iq1_s_q8_1(&x[ibx], &y[iby], iqs, iq1s_grid_gpu_ptr, ksigns64_ptr);
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
|
@ -10406,7 +10406,7 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
|
|||
dpct::queue_ptr stream) {
|
||||
const int nb = k / QK_K;
|
||||
{
|
||||
iq1s_grid.init(*stream);
|
||||
iq1s_grid_gpu.init(*stream);
|
||||
ksigns_iq2xs.init(*stream);
|
||||
kmask_iq2xs.init(*stream);
|
||||
|
||||
|
@ -10414,7 +10414,7 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
|
|||
{sycl::aspect::fp16});
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
auto iq1s_grid_ptr_ct1 = iq1s_grid.get_ptr();
|
||||
auto iq1s_grid_gpu_ptr_ct1 = iq1s_grid_gpu.get_ptr();
|
||||
auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr();
|
||||
auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr();
|
||||
|
||||
|
@ -10423,7 +10423,7 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
|
|||
sycl::range<3>(1, 1, 32)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
dequantize_block_iq1_s(
|
||||
vx, y, item_ct1, iq1s_grid_ptr_ct1,
|
||||
vx, y, item_ct1, iq1s_grid_gpu_ptr_ct1,
|
||||
ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1);
|
||||
});
|
||||
});
|
||||
|
@ -11154,11 +11154,11 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
|
|||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||
{
|
||||
iq1s_grid.init(*stream);
|
||||
iq1s_grid_gpu.init(*stream);
|
||||
ksigns64.init(*stream);
|
||||
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
auto iq1s_grid_ptr_ct1 = iq1s_grid.get_ptr();
|
||||
auto iq1s_grid_gpu_ptr_ct1 = iq1s_grid_gpu.get_ptr();
|
||||
auto ksigns64_ptr_ct1 = ksigns64.get_ptr();
|
||||
|
||||
cgh.parallel_for(
|
||||
|
@ -11167,7 +11167,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
|
|||
[[intel::reqd_sub_group_size(32)]] {
|
||||
mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
|
||||
vx, vy, dst, ncols, nrows, item_ct1,
|
||||
iq1s_grid_ptr_ct1, ksigns64_ptr_ct1);
|
||||
iq1s_grid_gpu_ptr_ct1, ksigns64_ptr_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue