diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dbf35c8df..7b0786036 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -522,10 +522,12 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong #define QI3_XS (QK_K / (4*QR3_XS)) typedef struct { half d; - uint8_t qs[3*(QK_K/8)]; + uint8_t qs[QK_K/4]; uint8_t qh[QK_K/32]; + uint8_t signs[QK_K/8]; + uint8_t scales[QK_K/64]; } block_iq3_xs; -static_assert(sizeof(block_iq3_xs) == sizeof(ggml_fp16_t) + 3*(QK_K/8) + QK_K/32, "wrong iq3_xs block size/padding"); +static_assert(sizeof(block_iq3_xs) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_xs block size/padding"); #define QR1_S 8 #define QI1_S (QK_K / (4*QR1_S)) @@ -2054,20 +2056,18 @@ template static __global__ void dequantize_block_iq3_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { const int i = blockIdx.x; - const block_iq3_xs * x = (const block_iq3_xs *) vx; + const block_iq3_xs * x = (const block_iq3_xs *) vx; const int tid = threadIdx.x; #if QK_K == 256 const int il = tid/8; // 0...3 const int ib = tid%8; // 0...7 dst_t * y = yy + i*QK_K + 32*ib + 8*il; - const uint8_t * q3 = x[i].qs + 8*ib; - const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib; - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (q3[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (q3[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); - const uint32_t aux32 = gas[0] | (gas[1] << 16); - const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f; - const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; + const uint8_t * qs = x[i].qs + 8*ib; + const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); + const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f; + const uint8_t signs = x[i].signs[4*ib + il]; for (int j = 0; j < 4; ++j) { y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); diff --git a/ggml-quants.c b/ggml-quants.c index 77bb9c4e0..e2549cf6d 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -3809,29 +3809,39 @@ void dequantize_row_iq3_xs(const block_iq3_xs * restrict x, float * restrict y, assert(k % QK_K == 0); const int nb = k / QK_K; - uint32_t aux32; - for (int i = 0; i < nb; i++) { const float d = GGML_FP16_TO_FP32(x[i].d); const uint8_t * qs = x[i].qs; - const uint8_t * scales_and_signs = qs + QK_K/4; const uint8_t * qh = x[i].qh; + const uint8_t * signs = x[i].signs; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - memcpy(&aux32, scales_and_signs + 4*ib32, sizeof(uint32_t)); - const float db = d * (0.5f + (aux32 >> 28)) * 0.5f; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const float db1 = d * (0.5f + (x[i].scales[ib32/2] & 0xf)) * 0.5f; + const float db2 = d * (0.5f + (x[i].scales[ib32/2] >> 4)) * 0.5f; for (int l = 0; l < 4; ++l) { - const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; - const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[ib32] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[ib32] << (7-2*l)) & 256))); + const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))); for (int j = 0; j < 4; ++j) { - y[j+0] = db * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); - y[j+4] = db * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + y[j+0] = db1 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db1 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); } y += 8; } qs += 8; + signs += 4; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3xs_grid + (qs[2*l+0] | ((qh[1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3xs_grid + (qs[2*l+1] | ((qh[1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + y[j+0] = db2 * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = db2 * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f); + } + y += 8; + } + qh += 2; + qs += 8; + signs += 4; } } } @@ -10702,7 +10712,7 @@ static int iq3_find_best_neighbour(const uint16_t * restrict neighbours, const u return grid_index; } -static void quantize_row_iq3_xs_impl(int grid_size, const float * restrict x, void * restrict vy, int n, +static void quantize_row_iq3_xxs_impl(int grid_size, const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) { const int gindex = iq3_data_index(grid_size); @@ -10921,7 +10931,7 @@ size_t quantize_iq3_xxs(const float * src, void * dst, int nrow, int n_per_row, int nblock = n_per_row/QK_K; char * qrow = (char *)dst; for (int row = 0; row < nrow; ++row) { - quantize_row_iq3_xs_impl(256, src, qrow, n_per_row, quant_weights); + quantize_row_iq3_xxs_impl(256, src, qrow, n_per_row, quant_weights); src += n_per_row; qrow += nblock*sizeof(block_iq3_xxs); } @@ -10936,7 +10946,189 @@ void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int k) { void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int k) { assert(k % QK_K == 0); - quantize_row_iq3_xs_impl(256, x, y, k, NULL); + quantize_row_iq3_xxs_impl(256, x, y, k, NULL); +} + +static void quantize_row_iq3_xs_impl(int block_size, const float * restrict x, void * restrict vy, int n, + const float * restrict quant_weights) { + + const int gindex = iq3_data_index(512); + + const uint32_t * kgrid_q3xs = iq3_data[gindex].grid; + const int * kmap_q3xs = iq3_data[gindex].map; + const uint16_t * kneighbors_q3xs = iq3_data[gindex].neighbours; + + //GGML_ASSERT(quant_weights && "missing quantization weights"); + GGML_ASSERT(kgrid_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kmap_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(kneighbors_q3xs && "forgot to call ggml_quantize_init()?"); + GGML_ASSERT(n%QK_K == 0); + + const int kMaxQ = 8; + + const int nbl = n/256; + + block_iq3_xs * y = vy; + + float scales[QK_K/block_size]; + float weight[block_size]; + float xval[block_size]; + int8_t L[block_size]; + int8_t Laux[block_size]; + float waux[block_size]; + bool is_on_grid[block_size/4]; + bool is_on_grid_aux[block_size/4]; + uint8_t block_signs[block_size/8]; + + const int bs4 = block_size/4; + const int bs8 = block_size/8; + + for (int ibl = 0; ibl < nbl; ++ibl) { + + memset(&y[ibl], 0, sizeof(block_iq3_xs)); + y[ibl].d = GGML_FP32_TO_FP16(0.f); + + uint8_t * qs = y[ibl].qs; + uint8_t * qh = y[ibl].qh; + uint8_t * signs = y[ibl].signs; + + float max_scale = 0; + + const float * xbl = x + QK_K*ibl; + float sumx2 = 0; + for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i]; + float sigma2 = 2*sumx2/QK_K; + + for (int ib = 0; ib < QK_K/block_size; ++ib) { + const float * xb = xbl + block_size*ib; + if (quant_weights) { + const float * qw = quant_weights + QK_K*ibl + block_size*ib; + for (int i = 0; i < block_size; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]); + } else { + for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i]; + } + for (int i = 0; i < block_size; ++i) waux[i] = sqrtf(weight[i]); + for (int k = 0; k < bs8; ++k) { + uint8_t s = 0; + for (int i = 0; i < 8; ++i) { + if (xb[8*k + i] >= 0) xval[8*k + i] = xb[8*k + i]; + else { + xval[8*k + i] = -xb[8*k + i]; s |= (1 << i); + } + } + block_signs[k] = s; + } + float max = xval[0]; + for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]); + if (!max) { + scales[ib] = 0; + continue; + } + float best = 0; + float scale = max/(2*kMaxQ-1); + for (int is = -15; is <= 15; ++is) { + float id = (2*kMaxQ-1+is*0.2f)/max; + float this_scale = 1/id; + for (int k = 0; k < bs4; ++k) { + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + Laux[4*k+i] = MAX(0, MIN(kMaxQ-1, l)); + } + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (Laux[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + is_on_grid_aux[k] = true; + if (grid_index < 0) { + is_on_grid_aux[k] = false; + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, this_scale, Laux + 4*k); + } + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < block_size; ++i) { + float w = weight[i]; + float q = 2*Laux[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0 && sumqx*sumqx > best*sumq2) { + scale = sumqx/sumq2; best = scale*sumqx; + for (int i = 0; i < block_size; ++i) L[i] = Laux[i]; + for (int k = 0; k < bs4; ++k) is_on_grid[k] = is_on_grid_aux[k]; + } + } + int n_not_ongrid = 0; + for (int k = 0; k < bs4; ++k) if (!is_on_grid[k]) ++n_not_ongrid; + if (n_not_ongrid > 0 && scale > 0) { + float id = 1/scale; + for (int k = 0; k < bs4; ++k) { + if (is_on_grid[k]) continue; + uint16_t u = 0; + for (int i = 0; i < 4; ++i) { + int l = nearest_int(0.5f*(id*xval[4*k+i]-1)); + l = MAX(0, MIN(kMaxQ-1, l)); + u |= (l << 3*i); + } + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + const uint16_t * neighbours = kneighbors_q3xs - kmap_q3xs[u] - 1; + grid_index = iq3_find_best_neighbour(neighbours, kgrid_q3xs, xval + 4*k, waux + 4*k, scale, L + 4*k); + } + const int8_t * pg = (const int8_t *)(kgrid_q3xs + grid_index); + for (int i = 0; i < 4; ++i) L[4*k+i] = (pg[i] - 1)/2; + } + float sumqx = 0, sumq2 = 0; + for (int i = 0; i < block_size; ++i) { + float w = weight[i]; + float q = 2*L[i] + 1; + sumqx += w*xval[i]*q; + sumq2 += w*q*q; + } + if (sumq2 > 0) scale = sumqx/sumq2; + } + if (scale < 0) { + // This should never happen, but just in case, flip scale so that it is positive (we use uint's to encode the scale) + // and correspondingly flip quant signs. + scale = -scale; + for (int k = 0; k < bs8; ++k) block_signs[k] = ~block_signs[k]; + } + for (int k = 0; k < bs4; ++k) { + uint16_t u = 0; + for (int i = 0; i < 4; ++i) u |= (L[4*k+i] << 3*i); + int grid_index = kmap_q3xs[u]; + if (grid_index < 0) { + printf("Oops: found point %u not on grid:", u); + for (int i = 0; i < 4; ++i) printf(" %d", L[4*k+i]); + printf("\n"); + GGML_ASSERT(false); + } + qs[k] = grid_index & 255; + qh[(ib*bs4+k)/8] |= ((grid_index >> 8) << ((ib*bs4+k)%8)); + } + qs += bs4; + for (int k = 0; k < bs8; ++k) signs[k] = block_signs[k]; + signs += bs8; + GGML_ASSERT(scale >= 0); + scales[ib] = scale; + max_scale = MAX(max_scale, scale); + } + + if (!max_scale) { + continue; + } + + float d = max_scale/31; + y[ibl].d = GGML_FP32_TO_FP16(d); + float id = 1/d; + for (int ib = 0; ib < QK_K/block_size; ib += 2) { + int l1 = nearest_int(0.5f*(id*scales[ib+0]-1)); + l1 = MAX(0, MIN(15, l1)); + int l2 = nearest_int(0.5f*(id*scales[ib+1]-1)); + l2 = MAX(0, MIN(15, l2)); + y[ibl].scales[ib/2] = l1 | (l2 << 4); + } + + } } size_t quantize_iq3_xs(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) { @@ -10945,7 +11137,7 @@ size_t quantize_iq3_xs(const float * src, void * dst, int nrow, int n_per_row, i int nblock = n_per_row/QK_K; char * qrow = (char *)dst; for (int row = 0; row < nrow; ++row) { - quantize_row_iq3_xs_impl(512, src, qrow, n_per_row, quant_weights); + quantize_row_iq3_xs_impl(32, src, qrow, n_per_row, quant_weights); src += n_per_row; qrow += nblock*sizeof(block_iq3_xs); } @@ -10960,7 +11152,7 @@ void quantize_row_iq3_xs(const float * restrict x, void * restrict vy, int k) { void quantize_row_iq3_xs_reference(const float * restrict x, block_iq3_xs * restrict y, int k) { assert(k % QK_K == 0); - quantize_row_iq3_xs_impl(512, x, y, k, NULL); + quantize_row_iq3_xs_impl(32, x, y, k, NULL); } diff --git a/ggml-quants.h b/ggml-quants.h index 9437cd69e..da2efff0f 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -194,10 +194,12 @@ static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong // 3.3125 bpw typedef struct { ggml_fp16_t d; - uint8_t qs[3*QK_K/8]; + uint8_t qs[QK_K/4]; uint8_t qh[QK_K/32]; + uint8_t signs[QK_K/8]; + uint8_t scales[QK_K/64]; } block_iq3_xs; -static_assert(sizeof(block_iq3_xs) == sizeof(ggml_fp16_t) + 3*(QK_K/8) + QK_K/32, "wrong iq3_xs block size/padding"); +static_assert(sizeof(block_iq3_xs) == sizeof(ggml_fp16_t) + 27*(QK_K/64), "wrong iq3_xs block size/padding"); typedef struct { ggml_fp16_t d;