iq1s_blocks16: CUDA dot product

This commit is contained in:
Iwan Kawrakow 2024-03-08 16:12:48 +02:00
parent 864a5c2ce4
commit f092d049fa

View file

@ -4536,44 +4536,37 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
#endif
}
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
const int ib32 = iqs;
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
const uint8_t h1 = bq1->qh[2*ib32+0]; //bq1->scales[2*ib32+0];
const uint8_t h2 = bq1->qh[2*ib32+1]; //bq1->scales[2*ib32+1];
//uint16_t h = bq1->qh[ib32];
int sumi = 0;
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const int * q8 = (const int *)bq8_1[ib32].qs;
const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
for (int j = 0; j < 2; ++j) {
sumi1 = __dp4a(q8[j+0], grid1[j], sumi1);
sumi2 = __dp4a(q8[j+2], grid2[j], sumi2);
sumi3 = __dp4a(q8[j+4], grid3[j], sumi3);
sumi4 = __dp4a(q8[j+6], grid4[j], sumi4);
for (int l = 0; l < 4; ++l) {
const int * grid = (const int *)(iq1s_grid + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
//const int * grid = (const int *)(iq1s_grid + (bq1->qs[4*ib32+l] | ((h & 7) << 8)));
sumi = __dp4a(q8[2*l+1], grid[1], __dp4a(q8[2*l+0], grid[0], sumi));
//h >>= 3;
}
#else
const int8_t * q8 = bq8_1[ib32].qs;
const int8_t * grid1 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
const int8_t * grid2 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
const int8_t * grid3 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
const int8_t * grid4 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
for (int j = 0; j < 8; ++j) {
sumi1 += q8[j+ 0] * grid1[j];
sumi2 += q8[j+ 8] * grid2[j];
sumi3 += q8[j+16] * grid3[j];
sumi4 += q8[j+24] * grid4[j];
for (int l = 0; l < 4; ++l) {
const int8_t * grid = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
//const int8_t * grid = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+l] | ((h & 7) << 8)));
for (int j = 0; j < 8; ++j) {
sumi += q8[j] * grid[j];
}
q8 += 8;
//h >>= 3;
}
#endif
const float d = (float)bq1->d * __low2float(bq8_1[ib32].ds);
return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));
return d * sumi * (2*(bq1->qh[ib32] >> 12) + 1);
//return d * sumi * (2*h + 1);
#else
assert(false);
return 0.f;