diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index bc2cc2435..16cfd1717 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -26,7 +26,7 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, - { "I2_S", LLAMA_FTYPE_MOSTLY_I2, " 2 bpw per-tensor", }, + { "I2_S", LLAMA_FTYPE_MOSTLY_I2_S, " 2 bpw per-tensor quantization", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, diff --git a/ggml-quants.c b/ggml-quants.c index a4a72c847..6a825cd74 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -659,6 +659,24 @@ static inline __m128i packNibbles( __m256i bytes ) { } #endif //__loongarch_asx +void quantize_row_i8_s(const float * x, void * y, int64_t n, float* act_scales) { + int8_t* dst = (int8_t*)y; + double min = 0.00001; + double max = min; + for (int i = 0; i < n; ++i) { + max = MAX(max, (double)fabs(x[i])); + } + float s = 127 / max; + act_scales[0] = s; + float temp; + for (int i = 0; i < n; ++i) { + temp = round(x[i] * s); + if (temp > 127) temp = 127; + if (temp < -128) temp = -128; + dst[i] = (int8_t)(temp); + } +} + // reference implementation for deterministic creation of model files void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; @@ -3308,7 +3326,9 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { // 2 bits per weight - size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row); + UNUSED(quant_weights); + + size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row); int n = nrow * n_per_row; @@ -3326,7 +3346,7 @@ size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nr q8[i] = 0; continue; } - q8[i] = src[i] * i2_scale > 0 ? 1 : 3; + q8[i] = (double)src[i] * i2_scale > 0 ? 1 : 3; } // q8 -> 0, 1, 3 @@ -3773,14 +3793,19 @@ static inline __m128i get_scale_shuffle(int i) { //====================================== I2 =============================================== -void ggml_vec_dot_i2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { +void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const uint8_t * restrict x = vx; const int8_t * restrict y = vy; + UNUSED(bs); + UNUSED(bx); + UNUSED(by); + UNUSED(nrc); + int sumi = 0; for (int i = 0; i < n / 4; i++) { - int8_t* weight = (const int8_t *)(i2_q8 + x[i]); + const int8_t* weight = (const int8_t *)(i2_q8 + x[i]); sumi += (int)y[i*4+0] * weight[0]; sumi += (int)y[i*4+1] * weight[1]; sumi += (int)y[i*4+2] * weight[2]; @@ -14431,7 +14456,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte case GGML_TYPE_I16: case GGML_TYPE_I32: case GGML_TYPE_I64: - case GGML_TYPE_I2: + case GGML_TYPE_I2_S: // nothing to validate break; default: diff --git a/ggml-quants.h b/ggml-quants.h index fea0b41ad..a4d0c0cec 100644 --- a/ggml-quants.h +++ b/ggml-quants.h @@ -51,6 +51,7 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_i8_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k, float* n); // Dequantization void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -99,7 +100,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_i2_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_i2_i8_s (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml.c b/ggml.c index f8752c0f8..55aa823c8 100644 --- a/ggml.c +++ b/ggml.c @@ -569,15 +569,6 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc); static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { - [GGML_TYPE_I2] = { - .type_name = "i2", - .blck_size = 1, - .type_size = sizeof(int8_t), - .is_quantized = true, - .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_i2_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, - .nrows = 1, - }, [GGML_TYPE_I8] = { .type_name = "i8", .blck_size = 1, @@ -922,6 +913,21 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, .vec_dot_type = GGML_TYPE_BF16, .nrows = 1, + }, + [GGML_TYPE_I2_S] = { + .type_name = "i2_s", + .blck_size = 1, + .type_size = sizeof(int8_t), + .is_quantized = true, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_i2_i8_s, + .vec_dot_type = GGML_TYPE_I8_S, + .nrows = 1, + }, + [GGML_TYPE_I8_S] = { + .type_name = "i8_s", + .blck_size = 1, + .type_size = sizeof(int8_t), + .is_quantized = true, } }; @@ -2630,33 +2636,6 @@ inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { *s = idx; } -inline static void ggml_vec_absmaxclamp_f32(const int n, float * s, float * x, float min) { - float max = min; - for (int i = 0; i < n; ++i) { - max = MAX(max, fabs(x[i])); - } - *s = max; -} - -inline static void ggml_vec_scaleroundclamp_f32(const int n, float * s, const float * x, float scale, float min, float max) { - for (int i = 0; i < n; ++i) { - s[i] = round(x[i] * scale); - if (s[i] > max) s[i] = max; - if (s[i] < min) s[i] = min; - s[i] /= scale; - } -} - -inline static void ggml_vec_scaleroundclamp_f32_v2(const int n, float * s, int8_t* inp, float scale, float min, float max) { - float temp; - for (int i = 0; i < n; ++i) { - temp = round(s[i] * scale); - if (temp > max) temp = max; - if (temp < min) temp = min; - inp[i] = (int8_t)(temp); - } -} - // // data types // @@ -12409,8 +12388,7 @@ static void ggml_compute_forward_mul_mat_one_chunk( //} for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) { - if (src0->type == 31) { - // printf("row->%ld\n", (ir0 * nb01 / 4)); + if (src0->type == GGML_TYPE_I2_S) { vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot); tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale); } else { @@ -12426,164 +12404,6 @@ static void ggml_compute_forward_mul_mat_one_chunk( } } - -static void ggml_compute_forward_bitnet_mul_mat( - const struct ggml_compute_params * params, - struct ggml_tensor * dst, - struct ggml_compute_state * state) { - - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src1 = dst->src[1]; - - int64_t t0 = ggml_perf_time_us(); - UNUSED(t0); - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const enum ggml_type type = src0->type; - const bool src1_cont = ggml_is_contiguous(src1); - - GGML_ASSERT(ne0 == ne01); - GGML_ASSERT(ne1 == ne11); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); - - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == ggml_type_size(type)); - GGML_ASSERT(nb10 == ggml_type_size(src1->type)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - // broadcast factors - const int64_t r2 = ne12 / ne02; - const int64_t r3 = ne13 / ne03; - UNUSED(r2); - UNUSED(r3); - - // nb01 >= nb00 - src0 is not transposed - // compute by src0 rows - if (params->type == GGML_TASK_TYPE_INIT) { - if (ith != 0) { - return; - } - atomic_store(&state->shared->current_chunk, nth); - char * wdata = params->wdata; - float* act_scales = (float*) ((char *) wdata + (ne11 * ne10)); - for (int64_t i13 = 0; i13 < ne13; i13++) { - for (int64_t i12 = 0; i12 < ne12; i12++) { - for (int64_t i11 = 0; i11 < ne11; i11++) { - float rowmax = 0.00001; - ggml_vec_absmaxclamp_f32(ne10, &rowmax, (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), 0.00001); - float s = 127 / rowmax; - act_scales[i11] = s; - ggml_vec_scaleroundclamp_f32_v2(ne10, - (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13), - (int8_t*) ((char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4)), - s, -128, 127); - } - } - } - // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - // atomic_store(&state->shared->current_chunk, nth); - // // char * wdata = params->wdata; - // const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, ne10); - // // printf("vec_dot_type:%d\n", vec_dot_type); - // // printf("row_size:%ld\n", row_size); - // assert(params->wsize >= ne11*ne12*ne13*row_size); - // GGML_ASSERT(src1->type == GGML_TYPE_F32); - - // for (int64_t i13 = 0; i13 < ne13; ++i13) { - // for (int64_t i12 = 0; i12 < ne12; ++i12) { - // for (int64_t i11 = 0; i11 < ne11; ++i11) { - // quantize_row_q8_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - // wdata += row_size; - // } - // } - // } - - - return; - } - - if (params->type == GGML_TASK_TYPE_FINALIZE) { - return; - } - - // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) - const int64_t nr0 = ne0; - - // This is the size of the rest of the dimensions of the result - const int64_t nr1 = ne1 * ne2 * ne3; - - // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols - int64_t num_rows_per_vec_dot = 1; - // TODO: currently the mmla kernels support only even numbered rows/cols. - // this check can be removed once they are extended to support odd numbered rows/cols too - if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { - num_rows_per_vec_dot = 1; - } - - // Now select a reasonable chunk size. - int chunk_size = 16; - - // We need to step up the size if it's small - if (nr0 == 1 || nr1 == 1) { - chunk_size = 64; - } - - // distribute the work across the inner or outer loop based on which one is larger - // The number of chunks in the 0/1 dim. - // CEIL(nr0/chunk_size) - int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size; - int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size; - - // If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread. - // Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915 - // In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that. - if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) { - // distribute the thread work across the inner or outer loop based on which one is larger - nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - } - - // The number of elements in each chunk - const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; - const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - - //if (ith == 0) - // printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1); - - // The first chunk comes from our thread_id, the rest will get auto-assigned. - int current_chunk = ith; - - while (current_chunk < nchunk0 * nchunk1) { - const int64_t ith0 = current_chunk % nchunk0; - const int64_t ith1 = current_chunk / nchunk0; - - const int64_t ir0_start = dr0 * ith0; - const int64_t ir0_end = MIN(ir0_start + dr0, nr0); - - const int64_t ir1_start = dr1 * ith1; - const int64_t ir1_end = MIN(ir1_start + dr1, nr1); - - ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); - - if (nth >= nchunk0 * nchunk1) { - break; - } - - current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1); - } - -} - static void ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, struct ggml_tensor * dst, @@ -12597,11 +12417,6 @@ static void ggml_compute_forward_mul_mat( GGML_TENSOR_BINARY_OP_LOCALS - if (src0->type == 31) { - ggml_compute_forward_bitnet_mul_mat(params, dst, state); - return; - } - const int ith = params->ith; const int nth = params->nth; @@ -12751,8 +12566,13 @@ UseGgmlGemm1:; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = 0; i11 < ne11; ++i11) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += row_size; + if (src0->type == GGML_TYPE_I2_S) { + float* act_scales = (float*) ((char *) wdata + (ne11 * ne10)); + quantize_row_i8_s((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4), ne10, act_scales + i11); + } else { + from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += row_size; + } } } } @@ -14469,7 +14289,8 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_I32: case GGML_TYPE_I64: case GGML_TYPE_F64: - case GGML_TYPE_I2: + case GGML_TYPE_I2_S: + case GGML_TYPE_I8_S: case GGML_TYPE_COUNT: { GGML_ASSERT(false); @@ -21727,7 +21548,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_I2: result = quantize_i2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_I2_S: result = quantize_i2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); @@ -21750,7 +21571,7 @@ size_t ggml_quantize_chunk( assert(false); } - if (type == GGML_TYPE_I2) { + if (type == GGML_TYPE_I2_S) { result = nrows * row_size / 4 + 32; } else { GGML_ASSERT(result == nrows * row_size); diff --git a/ggml.h b/ggml.h index eb9b12487..9edc84f5a 100644 --- a/ggml.h +++ b/ggml.h @@ -377,7 +377,8 @@ extern "C" { GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, GGML_TYPE_BF16 = 30, - GGML_TYPE_I2 = 31, + GGML_TYPE_I2_S = 31, + GGML_TYPE_I8_S = 32, GGML_TYPE_COUNT, }; diff --git a/llama.cpp b/llama.cpp index 109ac4034..865011f67 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15634,7 +15634,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; - case LLAMA_FTYPE_MOSTLY_I2: default_type = GGML_TYPE_I2; break; + case LLAMA_FTYPE_MOSTLY_I2_S: default_type = GGML_TYPE_I2_S; break; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: diff --git a/llama.h b/llama.h index 1a225fa61..5bfab5e03 100644 --- a/llama.h +++ b/llama.h @@ -156,7 +156,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors - LLAMA_FTYPE_MOSTLY_I2 = 33, + LLAMA_FTYPE_MOSTLY_I2_S = 33, LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file };