iq3_xs: make CUDA work for new version

This commit is contained in:
Iwan Kawrakow 2024-02-22 11:09:10 +02:00
parent d83fddaa3b
commit eacff4aa81

View file

@ -4822,6 +4822,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
#endif #endif
} }
// TODO: don't use lookup table for signs
static __device__ __forceinline__ float vec_dot_iq3_xs_q8_1( static __device__ __forceinline__ float vec_dot_iq3_xs_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
@ -4829,23 +4830,21 @@ static __device__ __forceinline__ float vec_dot_iq3_xs_q8_1(
const block_iq3_xs * bq2 = (const block_iq3_xs *) vbq; const block_iq3_xs * bq2 = (const block_iq3_xs *) vbq;
const int ib32 = iqs; const int ib32 = iqs;
const uint8_t * q3 = bq2->qs + 8*ib32; const uint8_t * qs = bq2->qs + 8*ib32;
const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32;
const int8_t * q8 = bq8_1[ib32].qs; const int8_t * q8 = bq8_1[ib32].qs;
uint32_t aux32 = gas[0] | (gas[1] << 16);
int sumi = 0; int sumi = 0;
for (int l = 0; l < 4; ++l) { for (int l = 0; l < 4; ++l) {
const uint32_t * grid1 = iq3xs_grid + (q3[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
const uint32_t * grid2 = iq3xs_grid + (q3[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127)); uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]); uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]); const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
sumi = __dp4a(grid_l, *((int *)q8+0), sumi); sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
sumi = __dp4a(grid_h, *((int *)q8+1), sumi); sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
q8 += 8; q8 += 8;
aux32 >>= 7;
} }
const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f; const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f;
return d * sumi; return d * sumi;
#else #else
assert(false); assert(false);