iq3_s_mult_shuffle: mult + shuffle based codebook

This commit is contained in:
Iwan Kawrakow 2024-03-04 19:43:22 +02:00
parent b48bf8b411
commit b587482287
2 changed files with 81 additions and 46 deletions

View file

@ -2375,8 +2375,12 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
// Better (lower PPL), but requires more bit twidling, so slower
#define IQ3S_MULTIPLIER 190842953LL
#else
#define IQ3S_MULTIPLIER 898886
#define IQ3S_MULTIPLIER 72968561ULL
//#define IQ3S_MULTIPLIER 540201
//#define IQ3S_MULTIPLIER 1378231
//#define IQ3S_MULTIPLIER 898886
//#define IQ3S_MULTIPLIER 842866
static const __device__ uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 5, 7, 7, 9, 9, 11, 13, 15};
#endif
template<typename dst_t>
@ -2400,32 +2404,36 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
#else
aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
//aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
//aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
aux32[0] = (((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
aux32[1] = (((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
#endif
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
#ifdef IQ3S_SLOW_MULT
aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101;
aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101;
#endif
uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201);
aux32[0] = __vsub4(aux32[0] ^ signs0, signs0);
aux32[1] = __vsub4(aux32[1] ^ signs1, signs1);
for (int j = 0; j < 8; ++j) {
y[j] = d * grid[j];
}
#else
//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
//#ifdef IQ3S_SLOW_MULT
// aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101;
// aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101;
//#endif
// uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
// uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201);
// aux32[0] = __vsub4(aux32[0] ^ signs0, signs0);
// aux32[1] = __vsub4(aux32[1] ^ signs1, signs1);
// for (int j = 0; j < 8; ++j) {
// y[j] = d * grid[j];
// }
//#else
#ifdef IQ3S_SLOW_MULT
for (int j = 0; j < 8; ++j) {
y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
#else
//static const uint8_t k_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 15};
for (int j = 0; j < 8; ++j) {
y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
//y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
#endif
#endif
//#endif
#else
assert(false);
#endif
@ -5225,7 +5233,6 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
#endif
}
// TODO: don't use lookup table for signs
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
@ -5233,6 +5240,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
uint32_t aux32[2];
uint8_t * aux8 = (uint8_t *)aux32;
const int ib32 = iqs;
const uint8_t * qs = bq2->qs + 8*ib32;
@ -5249,8 +5257,11 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101;
aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101;
#else
aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
//aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
//aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
aux32[0] = (((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
aux32[1] = (((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
for (int j = 0; j < 8; ++j) aux8[j] = iq3s_values[aux8[j]];
#endif
uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);