This commit is contained in:
Iwan Kawrakow 2024-03-02 17:56:16 +02:00
parent e43e81a5d7
commit 0fe9cd488f
2 changed files with 21 additions and 30 deletions

View file

@ -2371,8 +2371,8 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
}
//#define IQ3S_MULTIPLIER 746226
//#define IQ3S_MULTIPLIER 717154
#define IQ3S_MULTIPLIER 677595
//#define IQ3S_MULTIPLIER 677595
#define IQ3S_MULTIPLIER 190842953LL
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
@ -2393,8 +2393,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
aux32[0] = ((qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
aux32[1] = ((qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
aux32[0] = (((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101;
aux32[1] = (((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101;
aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101;
aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0x00000000) >> 1) & 0x07070707) << 1) | 0x01010101;
uint32_t signs0 = __vcmpeq4(((signs & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
uint32_t signs1 = __vcmpeq4(((signs >> 4) * 0x01010101) & 0x08040201, 0x08040201);
aux32[0] = __vsub4(aux32[0] ^ signs0, signs0);
@ -2404,9 +2404,7 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
}
#else
for (int j = 0; j < 8; ++j) {
//y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
//y[j] = d * iq3s_values[grid[j]] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
y[j] = d * (2*(((grid[j]+1)/2) & 7) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
y[j] = d * (2*((grid[j]-1)/2) + 1) * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
#endif
#else
@ -5216,7 +5214,6 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
uint32_t aux32[2];
const uint8_t * grid = (const uint8_t *)aux32;
const int ib32 = iqs;
const uint8_t * qs = bq2->qs + 8*ib32;
@ -5225,25 +5222,16 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
for (int l = 0; l < 4; ++l) {
aux32[0] = ((qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
aux32[1] = ((qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f;
aux32[0] = (((__vadd4(aux32[0], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101;
aux32[1] = (((__vadd4(aux32[1], 0x01010101) >> 1) & 0x07070707) << 1) | 0x01010101;
aux32[0] = (((__vmaxs4(__vsub4(aux32[0], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101;
aux32[1] = (((__vmaxs4(__vsub4(aux32[1], 0x01010101), 0) >> 1) & 0x07070707) << 1) | 0x01010101;
uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
const int grid_l = __vsub4(aux32[0] ^ signs0, signs0);
const int grid_h = __vsub4(aux32[1] ^ signs1, signs1);
sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
//const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256));
//const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256));
//uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201);
//uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201);
//const int grid_l = __vsub4(grid1[0] ^ signs0, signs0);
//const int grid_h = __vsub4(grid2[0] ^ signs1, signs1);
//sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
//sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
q8 += 8;
}
//const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f;
const float d = (float)bq2->d * (1 + 2*((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds);
return d * sumi;
#else