iq1s_blocks16: uint32_t codebook is also better in CUDA
TG-128 is now 204 t/s up from 194 t/s. PP-512 is 5890 t/s, so significantly better than other quants
This commit is contained in:
parent
7545d69312
commit
156220f8ca
1 changed files with 21 additions and 12 deletions
33
ggml-cuda.cu
33
ggml-cuda.cu
|
@ -1722,9 +1722,22 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
|
||||||
const int il = tid/8; // 0...3
|
const int il = tid/8; // 0...3
|
||||||
const int ib = tid%8; // 0...7
|
const int ib = tid%8; // 0...7
|
||||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||||
const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)));
|
|
||||||
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1);
|
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1);
|
||||||
for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||||
|
int grid32[2]; const int8_t * q = (const int8_t *)grid32;
|
||||||
|
grid32[0] = *((const int *)(iq1s_grid + (x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8))));
|
||||||
|
grid32[1] = __vsub4((grid32[0] >> 4) & 0x0f0f0f0f, 0x01010101);
|
||||||
|
grid32[0] = __vsub4(grid32[0] & 0x0f0f0f0f, 0x01010101);
|
||||||
|
for (int j = 0; j < 8; ++j) {
|
||||||
|
y[j] = d * q[j];
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
const uint8_t * grid = (const uint8_t *)(iq1s_grid + (x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)));
|
||||||
|
for (int j = 0; j < 4; ++j) {
|
||||||
|
y[j+0] = d * ((grid[j] & 0xf) - 1);
|
||||||
|
y[j+4] = d * ((grid[j] >> 4) - 1);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
#else
|
#else
|
||||||
assert(false);
|
assert(false);
|
||||||
#endif
|
#endif
|
||||||
|
@ -4542,31 +4555,27 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
|
||||||
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
|
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
|
||||||
|
|
||||||
const int ib32 = iqs;
|
const int ib32 = iqs;
|
||||||
//uint16_t h = bq1->qh[ib32];
|
|
||||||
int sumi = 0;
|
int sumi = 0;
|
||||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||||
const int * q8 = (const int *)bq8_1[ib32].qs;
|
const int * q8 = (const int *)bq8_1[ib32].qs;
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
const int * grid = (const int *)(iq1s_grid + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
|
const int * grid = (const int *)(iq1s_grid + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
|
||||||
//const int * grid = (const int *)(iq1s_grid + (bq1->qs[4*ib32+l] | ((h & 7) << 8)));
|
int grid0 = __vsub4(grid[0] & 0x0f0f0f0f, 0x01010101);
|
||||||
sumi = __dp4a(q8[2*l+1], grid[1], __dp4a(q8[2*l+0], grid[0], sumi));
|
int grid1 = __vsub4((grid[0] >> 4) & 0x0f0f0f0f, 0x01010101);
|
||||||
//h >>= 3;
|
sumi = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi));
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
const int8_t * q8 = bq8_1[ib32].qs;
|
const int8_t * q8 = bq8_1[ib32].qs;
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
const int8_t * grid = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
|
const uint8_t * grid = (const uint8_t *)(iq1s_grid + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
|
||||||
//const int8_t * grid = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+l] | ((h & 7) << 8)));
|
for (int j = 0; j < 4; ++j) {
|
||||||
for (int j = 0; j < 8; ++j) {
|
sumi += q8[j] * ((grid[j] & 0xf) - 1) + q8[j+4] * ((grid[j] >> 4) - 1);
|
||||||
sumi += q8[j] * grid[j];
|
|
||||||
}
|
}
|
||||||
q8 += 8;
|
q8 += 8;
|
||||||
//h >>= 3;
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
const float d = (float)bq1->d * __low2float(bq8_1[ib32].ds);
|
const float d = (float)bq1->d * __low2float(bq8_1[ib32].ds);
|
||||||
return d * sumi * (2*(bq1->qh[ib32] >> 12) + 1);
|
return d * sumi * (2*(bq1->qh[ib32] >> 12) + 1);
|
||||||
//return d * sumi * (2*h + 1);
|
|
||||||
#else
|
#else
|
||||||
assert(false);
|
assert(false);
|
||||||
return 0.f;
|
return 0.f;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue