This commit is contained in:
abhilash1910 2024-03-14 05:06:22 -07:00
parent 81b6139f4c
commit 0af3ed733f

View file

@@ -4739,7 +4739,7 @@ static void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restr
const uint8_t *ksigns_iq2xs, const uint8_t *ksigns_iq2xs,
const uint8_t *kmask_iq2xs) { const uint8_t *kmask_iq2xs) {
const int i = item_ct1.get_group(2); const int i = item_ct1.get_group(2);
const block_iq2_s * x = (const block_iq1_s *) vx; const block_iq2_s * x = (const block_iq2_s *) vx;
const int tid = item_ct1.get_local_id(2); const int tid = item_ct1.get_local_id(2);
#if QK_K == 256 #if QK_K == 256
@@ -4747,8 +4747,8 @@ static void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restr
const int ib = tid%8; // 0...7 const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il; dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * qs = x[i].qs + 8*ib; const uint8_t * qs = x[i].qs + 8*ib;
const uint8_t * grid1 = (const uint8_t *)(iq1s_grid + qs[2*il+0]); const uint8_t * grid1 = (const uint8_t *)(iq2s_grid + qs[2*il+0]);
const uint8_t * grid2 = (const uint8_t *)(iq1s_grid + qs[2*il+1]); const uint8_t * grid2 = (const uint8_t *)(iq2s_grid + qs[2*il+1]);
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[(x[i].qh[ib] >> 3*il) & 7]; const uint8_t signs = ksigns_iq2xs[(x[i].qh[ib] >> 3*il) & 7];
for (int j = 0; j < 4; ++j) { for (int j = 0; j < 4; ++j) {
@@ -7686,7 +7686,7 @@ vec_dot_iq2_s_q8_1(const void *__restrict__ vbq,
const block_iq2_s * bq2 = (const block_iq2_s *) vbq; const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
const int ib32 = iqs; const int ib32 = iqs;
const uint8_t * q8 = bq8_1[ib32].qs; const int8_t * q8 = bq8_1[ib32].qs;
const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32; const uint8_t * signs = bq2->qs + QK_K/8 + 4*ib32;
const uint8_t ls1 = bq2->scales[ib32] & 0xf; const uint8_t ls1 = bq2->scales[ib32] & 0xf;
const uint8_t ls2 = bq2->scales[ib32] >> 4; const uint8_t ls2 = bq2->scales[ib32] >> 4;