iq1_s: attempt to fix SYCL
This commit is contained in:
parent
8030da7afe
commit
da5a6f05f6
1 changed files with 18 additions and 24 deletions
|
@ -4712,14 +4712,14 @@ static void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restr
|
||||||
const int il = tid/8; // 0...3
|
const int il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const uint8_t * qs = x[i].qs + 8*ib;
|
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
|
||||||
const uint8_t * grid1 = (const uint8_t *)(iq1s_grid + qs[2*il+0]);
|
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
|
||||||
const uint8_t * grid2 = (const uint8_t *)(iq1s_grid + qs[2*il+1]);
|
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
|
||||||
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1);
|
grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
|
||||||
const uint8_t signs = ksigns_iq2xs[(x[i].qh[ib] >> 3*il) & 7];
|
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
|
||||||
for (int j = 0; j < 4; ++j) {
|
grid32[0] &= 0x0f0f0f0f;
|
||||||
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
for (int j = 0; j < 8; ++j) {
|
||||||
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
y[j] = d * (q[j] + delta);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
assert(false);
|
assert(false);
|
||||||
|
@ -7619,24 +7619,18 @@ vec_dot_iq1_s_q8_1(const void *__restrict__ vbq,
|
||||||
const uint32_t *iq1s_grid, const uint64_t *ksigns64) {
|
const uint32_t *iq1s_grid, const uint64_t *ksigns64) {
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
|
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
|
||||||
|
const int * q8 = (const int *)bq8_1[ib32].qs;
|
||||||
const int ib32 = iqs;
|
|
||||||
const uint8_t * qs = bq1->qs + 4*ib32;
|
|
||||||
const int8_t * q8 = bq8_1[ib32].qs;
|
|
||||||
int sumi = 0;
|
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
const uint32_t * grid = (const uint32_t *)(iq1s_grid + qs[l]);
|
const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
|
||||||
const uint32_t * signs = (const uint32_t *)(ksigns64 + (qs[l] >> 8));
|
int grid0 = grid[0] & 0x0f0f0f0f;
|
||||||
const int grid_l = dpct::vectorized_binary<sycl::uchar4>(
|
int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
|
||||||
grid[0] ^ signs[0], signs[0], std::minus<>());
|
sumi = dpct::dp4a(q8[2*l+1], grid1, dpct::dp4a(q8[2*l+0], grid0, sumi));
|
||||||
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
|
|
||||||
grid[1] ^ signs[1], signs[1], std::minus<>());
|
|
||||||
sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
|
|
||||||
sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
|
|
||||||
q8 += 8;
|
|
||||||
}
|
}
|
||||||
const float d = (float)bq1->d * bq8_1[ib32].ds[0] * 0.25f;
|
const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
|
||||||
return d * sumi;
|
const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
|
||||||
|
const float d = d1q * (float)bq8_1[ib32].ds[0];
|
||||||
|
const float m = d1q * (float)bq8_1[ib32].ds[1];
|
||||||
|
return d * sumi + m * delta;
|
||||||
#else
|
#else
|
||||||
assert(false);
|
assert(false);
|
||||||
return 0.f;
|
return 0.f;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue