1.5 bit: we can do even better (#5999)
* iq1_s: we can do even better Spent one of the 4 scale bits on a signs of a 0.125 shift. I.e., quants are now -1 + delta, delta, 1 + delta, where delta is +/- 0.125. CUDA works, same performance as before. PPL(LLaMA-v2-7B) is now 11.85! * iq1_s: make scalar and AVX2 work with the new version * iq1_s: make Neon work with new version. ~10% drop in performance, so will need some more work. * iq1_s: make Metal work with new version * iq1_s: very slightly faster dequantize on Metal * iq1_s: fix dequantize on the CPU --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
05b06210c9
commit
44ca159faf
4 changed files with 83 additions and 55 deletions
36
ggml-cuda.cu
36
ggml-cuda.cu
|
@ -1722,22 +1722,15 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
|
|||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 0xf) + 1);
|
||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
int grid32[2]; const int8_t * q = (const int8_t *)grid32;
|
||||
grid32[0] = *((const int *)(iq1s_grid_gpu + (x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8))));
|
||||
grid32[1] = __vsub4((grid32[0] >> 4) & 0x0f0f0f0f, 0x01010101);
|
||||
grid32[0] = __vsub4(grid32[0] & 0x0f0f0f0f, 0x01010101);
|
||||
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
|
||||
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
|
||||
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
|
||||
grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
|
||||
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
|
||||
grid32[0] &= 0x0f0f0f0f;
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
y[j] = d * q[j];
|
||||
y[j] = d * (q[j] + delta);
|
||||
}
|
||||
#else
|
||||
const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)));
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+0] = d * ((grid[j] & 0xf) - 1);
|
||||
y[j+4] = d * ((grid[j] >> 4) - 1);
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
|
@ -4560,22 +4553,25 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
|
|||
const int * q8 = (const int *)bq8_1[ib32].qs;
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
|
||||
int grid0 = __vsub4(grid[0] & 0x0f0f0f0f, 0x01010101);
|
||||
int grid1 = __vsub4((grid[0] >> 4) & 0x0f0f0f0f, 0x01010101);
|
||||
int grid0 = grid[0] & 0x0f0f0f0f;
|
||||
int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
|
||||
sumi = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi));
|
||||
}
|
||||
#else
|
||||
const int8_t * q8 = bq8_1[ib32].qs;
|
||||
const int8_t * q8 = bq8_1[ib32].qs;
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
sumi += q8[j] * ((grid[j] & 0xf) - 1) + q8[j+4] * ((grid[j] >> 4) - 1);
|
||||
sumi += q8[j] * (grid[j] & 0xf) + q8[j+4] * (grid[j] >> 4);
|
||||
}
|
||||
q8 += 8;
|
||||
}
|
||||
#endif
|
||||
const float d = (float)bq1->d * __low2float(bq8_1[ib32].ds);
|
||||
return d * sumi * (2*(bq1->qh[ib32] >> 12) + 1);
|
||||
const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
|
||||
const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
|
||||
const float d = d1q * __low2float (bq8_1[ib32].ds);
|
||||
const float m = d1q * __high2float(bq8_1[ib32].ds);
|
||||
return d * sumi + m * delta;
|
||||
#else
|
||||
assert(false);
|
||||
return 0.f;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue