diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 754edf014..9a217c1c7 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1413,17 +1413,47 @@ class BitnetModel(Model):
         dtype = weight.dtype
         weight = weight.float()
         s = 1 / weight.abs().mean().clamp(min=1e-5)
-        result = (weight * s).round().clamp(-1, 1) / s
-        return result.type(dtype)
+        weight = (weight * s).round().clamp(-1, 1) / s
+        scale = weight.abs().max().unsqueeze(0)
+        weight = torch.where(weight.abs().less(1e-6), 0, weight).type(dtype)
+        weight = torch.sign(weight).type(dtype)
+        return weight.type(dtype), scale.type(torch.float32)

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # transform weight into 1/0/-1 (in fp32)
         if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
                           "down_proj.weight", "up_proj.weight", "gate_proj.weight",
                           "o_proj.weight")):
-            data_torch = self.weight_quant(data_torch)
+            weight_torch, scale_torch = self.weight_quant(data_torch)

-        return [(self.map_tensor_name(name), data_torch)]
+        tensors: list[tuple[str, Tensor]] = []
+
+        if name.endswith("q_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q_SCALE, bid), scale_torch))
+        elif name.endswith("k_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K_SCALE, bid), scale_torch))
+        elif name.endswith("v_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V_SCALE, bid), scale_torch))
+        elif name.endswith("o_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT_SCALE, bid), scale_torch))
+        elif name.endswith("up_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SCALE, bid), scale_torch))
+        elif name.endswith("down_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_SCALE, bid), scale_torch))
+        elif name.endswith("gate_proj.weight"):
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), weight_torch))
+            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SCALE, bid), scale_torch))
+
+        if len(tensors) == 0:
+            tensors.append((self.map_tensor_name(name), data_torch))
+
+        return tensors


 @Model.register("GrokForCausalLM")
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 16cfd1717..05df330c0 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -26,7 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ2_M",   LLAMA_FTYPE_MOSTLY_IQ2_M,   " 2.7 bpw quantization", },
     { "IQ1_S",   LLAMA_FTYPE_MOSTLY_IQ1_S,   " 1.56 bpw quantization", },
     { "IQ1_M",   LLAMA_FTYPE_MOSTLY_IQ1_M,   " 1.75 bpw quantization", },
-    { "I2_S",    LLAMA_FTYPE_MOSTLY_I2_S,    " 2 bpw per-tensor quantization", },
+    { "Q2_2",    LLAMA_FTYPE_MOSTLY_Q2_2,    " 2 bpw quantization", },
     { "Q2_K",    LLAMA_FTYPE_MOSTLY_Q2_K,    " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q2_K_S",  LLAMA_FTYPE_MOSTLY_Q2_K_S,  " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
     { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },
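For reference, the per-tensor absmean ternarization that BitnetModel.weight_quant performs above can be sketched roughly in plain C as follows; the helper name ternarize and the toy weights are illustrative only, not part of the patch:

    #include <math.h>
    #include <stdio.h>

    // scale = mean(|w|) clamped away from zero; each weight becomes the sign of round(w / scale)
    static float ternarize(const float * w, int8_t * q, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) sum += fabs(w[i]);
        const float scale = (float) fmax(sum / n, 1e-5);   // stored separately as the *_scale tensor

        for (int i = 0; i < n; ++i) {
            const float r = roundf(w[i] / scale);          // sign of the rounded value = clamp to {-1, 0, +1}
            q[i] = r > 0.0f ? 1 : (r < 0.0f ? -1 : 0);
        }
        return scale;
    }

    int main(void) {
        const float w[4] = { 0.31f, -0.74f, 0.02f, 1.20f };
        int8_t q[4];
        const float scale = ternarize(w, q, 4);
        printf("scale = %.4f, q = %d %d %d %d\n", scale, q[0], q[1], q[2], q[3]);
        // scale = 0.5675, q = 1 -1 0 1; the ternary tensor and the scale are written as separate GGUF tensors
        return 0;
    }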
LLaMA-v1-7B", }, { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", }, diff --git a/ggml-common.h b/ggml-common.h index 409fcf29e..be88daa36 100644 --- a/ggml-common.h +++ b/ggml-common.h @@ -137,6 +137,13 @@ typedef sycl::half2 ggml_half2; #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP +#define QK2_2 32 +typedef struct { + ggml_half d; // delta + uint8_t qs[QK2_2 / 4]; // nibbles / quants +} block_q2_2; +static_assert(sizeof(block_q2_2) == sizeof(ggml_half) + QK2_2 / 4, "wrong q4_0 block size/padding"); + #define QK4_0 32 typedef struct { ggml_half d; // delta @@ -1022,7 +1029,7 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, GGML_TABLE_END() -GGML_TABLE_BEGIN(uint32_t, i2s_i8s, 256) +GGML_TABLE_BEGIN(uint32_t, q22_grid, 256) 0x00000000, 0x01000000, 0x00000000, 0xff000000, 0x00010000, 0x01010000, 0x00010000, 0xff010000, 0x00000000, 0x01000000, 0x00000000, 0xff000000, diff --git a/ggml-quants.c b/ggml-quants.c index 4b5209279..aebeb0217 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -659,25 +659,44 @@ static inline __m128i packNibbles( __m256i bytes ) { } #endif //__loongarch_asx -void quantize_row_i8_s(const float * x, void * y, int64_t n, float* act_scales) { - int8_t* dst = (int8_t*)y; - double min = 0.00001; - double max = min; - for (int i = 0; i < n; ++i) { - max = MAX(max, (double)fabs((double)x[i])); - } - float s = 127 / max; - act_scales[0] = s; - float temp; - for (int i = 0; i < n; ++i) { - temp = round((double)(x[i] * s)); - if (temp > 127) temp = 127; - if (temp < -128) temp = -128; - dst[i] = (int8_t)(temp); +void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict y, int64_t k) { + static const int qk = QK2_2; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + + const float d = 1.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < qk/4; ++j) { + int8_t x0 = (int8_t)x[i*qk + j*4 + 0]; + int8_t x1 = (int8_t)x[i*qk + j*4 + 1]; + int8_t x2 = (int8_t)x[i*qk + j*4 + 2]; + int8_t x3 = (int8_t)x[i*qk + j*4 + 3]; + + const uint8_t xi0 = x0 >= 0 ? x0 : 3; + const uint8_t xi1 = x1 >= 0 ? x1 : 3; + const uint8_t xi2 = x2 >= 0 ? x2 : 3; + const uint8_t xi3 = x3 >= 0 ? x3 : 3; + + y[i].qs[j] = 0; + y[i].qs[j] |= (xi0 << 6); + y[i].qs[j] |= (xi1 << 4); + y[i].qs[j] |= (xi2 << 2); + y[i].qs[j] |= (xi3 << 0); + } } } // reference implementation for deterministic creation of model files +void quantize_row_q2_2(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q2_2_reference(x, y, k); +} + void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; @@ -3324,48 +3343,11 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr return nrow * row_size; } -size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - // 2 bits per weight - UNUSED(quant_weights); - - size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row); - - int n = nrow * n_per_row; - - // f32 -> q8 - double max = 0; - for (int i = 0; i < n; ++i) { - max = MAX(max, (double)fabs((double)src[i])); - } - double i2_scale = max; - - uint8_t* q8 = (uint8_t*)dst; - for (int i=0; i 0 ? 
@@ -3324,48 +3343,11 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
     return nrow * row_size;
 }

-size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
-    // 2 bits per weight
-    UNUSED(quant_weights);
-
-    size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
-
-    int n = nrow * n_per_row;
-
-    // f32 -> q8
-    double max = 0;
-    for (int i = 0; i < n; ++i) {
-        max = MAX(max, (double)fabs((double)src[i]));
-    }
-    double i2_scale = max;
-
-    uint8_t* q8 = (uint8_t*)dst;
-    for (int i=0; i<n; i++) {
-        q8[i] = src[i] * i2_scale > 0 ? 1 : 3;
-    }
-
-    // q8 -> 0, 1, 3
-    //       |  |  |
-    //       0, 1,-1
-
-    uint8_t* i2_weight = (uint8_t*)dst;
-    for (int i=0; i<n; i++) {

diff --git a/ggml.c b/ggml.c
             nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
         }
-        if(tensor->type == GGML_TYPE_I2_S){
-            nbytes = nbytes / 4 + 32;
-        }
     }
     else {
         nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
@@ -12289,10 +12282,6 @@ static void ggml_compute_forward_mul_mat_one_chunk(
     // 16 * 2, accounting for mmla kernels
     float tmp[32];

-    // for per-tensor quant
-    const float * scale = (float * )((uint8_t*) (src0->data) + (ne00 * ne01 / 4));
-    const float * act_scales = (const float*) ((const char *) wdata + (ne11 * ne10));
-
     for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
         for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
             for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
@@ -12325,12 +12314,7 @@ static void ggml_compute_forward_mul_mat_one_chunk(
                 //}

                 for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
-                    if (src0->type == GGML_TYPE_I2_S) {
-                        vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
-                        tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale);
-                    } else {
-                        vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
-                    }
+                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
                 }

                 for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
@@ -12494,13 +12478,8 @@ UseGgmlGemm1:;
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = 0; i11 < ne11; ++i11) {
-                    if (src0->type == GGML_TYPE_I2_S) {
-                        float* act_scales = (float*) ((char *) wdata + (ne11 * ne10));
-                        quantize_row_i8_s((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4), ne10, act_scales + i11);
-                    } else {
-                        from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
-                        wdata += row_size;
-                    }
+                    from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+                    wdata += row_size;
                 }
             }
         }
@@ -14189,6 +14168,7 @@ static void ggml_compute_forward_clamp(
             } break;
         case GGML_TYPE_F16:
         case GGML_TYPE_BF16:
+        case GGML_TYPE_Q2_2:
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
         case GGML_TYPE_Q5_0:
@@ -14215,8 +14195,6 @@ static void ggml_compute_forward_clamp(
         case GGML_TYPE_I32:
         case GGML_TYPE_I64:
         case GGML_TYPE_F64:
-        case GGML_TYPE_I2_S:
-        case GGML_TYPE_I8_S:
         case GGML_TYPE_COUNT:
             {
                 GGML_ASSERT(false);
@@ -21340,6 +21318,7 @@ size_t ggml_quantize_chunk(
     size_t result = 0;

     switch (type) {
+        case GGML_TYPE_Q2_2:    result = quantize_q2_2(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q4_0:    result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q4_1:    result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q5_0:    result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
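The uniform GGML_ASSERT(result == nrows * row_size) restored below holds for Q2_2 because, unlike the removed I2_S path (a whole-tensor blob of n/4 bytes plus a 32-byte scale tail), Q2_2 is an ordinary block format: 10 bytes per 32 weights (a 2-byte ggml_half d plus 8 bytes of quants), i.e. 2.5 bits per weight. A small back-of-the-envelope check, with the 4096-wide row being an arbitrary example:

    #include <stdio.h>

    #define QK2_2 32

    int main(void) {
        const unsigned block_bytes = 2 + QK2_2 / 4;                // sizeof(block_q2_2) = 10
        const unsigned n_per_row   = 4096;                         // illustrative row width
        const unsigned row_bytes   = n_per_row / QK2_2 * block_bytes;
        printf("%u bytes per row, %.2f bits per weight\n",
               row_bytes, 8.0 * block_bytes / QK2_2);              // 1280 bytes, 2.50 bpw
        return 0;
    }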
@@ -21359,7 +21338,6 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_IQ1_M:   result = quantize_iq1_m  (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ4_NL:  result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_IQ4_XS:  result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
-        case GGML_TYPE_I2_S:    result = quantize_i2_s   (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_F16:
             {
                 size_t elemsize = sizeof(ggml_fp16_t);
@@ -21382,11 +21360,7 @@ size_t ggml_quantize_chunk(
             assert(false);
     }

-    if (type == GGML_TYPE_I2_S) {
-        result = nrows * row_size / 4 + 32;
-    } else {
-        GGML_ASSERT(result == nrows * row_size);
-    }
+    GGML_ASSERT(result == nrows * row_size);

     return result;
 }
diff --git a/ggml.h b/ggml.h
index c2e6859f5..4ec555ccb 100644
--- a/ggml.h
+++ b/ggml.h
@@ -377,8 +377,7 @@ extern "C" {
         GGML_TYPE_F64     = 28,
         GGML_TYPE_IQ1_M   = 29,
         GGML_TYPE_BF16    = 30,
-        GGML_TYPE_I2_S    = 31,
-        GGML_TYPE_I8_S    = 32,
+        GGML_TYPE_Q2_2    = 31,
         GGML_TYPE_COUNT,
     };
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 78c7290d2..7f2c10601 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -202,6 +202,13 @@ class MODEL_TENSOR(IntEnum):
     ATTN_KV_A_NORM = auto()
     FFN_SUB_NORM   = auto()
     ATTN_SUB_NORM  = auto()
+    ATTN_Q_SCALE   = auto()
+    ATTN_K_SCALE   = auto()
+    ATTN_V_SCALE   = auto()
+    ATTN_OUT_SCALE = auto()
+    FFN_UP_SCALE   = auto()
+    FFN_DOWN_SCALE = auto()
+    FFN_GATE_SCALE = auto()


 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -293,6 +300,13 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
     MODEL_TENSOR.ATTN_SUB_NORM:  "blk.{bid}.attn_sub_norm",
     MODEL_TENSOR.FFN_SUB_NORM:   "blk.{bid}.ffn_sub_norm",
+    MODEL_TENSOR.ATTN_Q_SCALE:   "blk.{bid}.attn_q_scale",
+    MODEL_TENSOR.ATTN_K_SCALE:   "blk.{bid}.attn_k_scale",
+    MODEL_TENSOR.ATTN_V_SCALE:   "blk.{bid}.attn_v_scale",
+    MODEL_TENSOR.ATTN_OUT_SCALE: "blk.{bid}.attn_output_scale",
+    MODEL_TENSOR.FFN_UP_SCALE:   "blk.{bid}.ffn_up_scale",
+    MODEL_TENSOR.FFN_DOWN_SCALE: "blk.{bid}.ffn_down_scale",
+    MODEL_TENSOR.FFN_GATE_SCALE: "blk.{bid}.ffn_gate_scale",
 }

 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -819,17 +833,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
-        MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_NORM,
-        MODEL_TENSOR.ATTN_QKV,
         MODEL_TENSOR.ATTN_OUT,
-        MODEL_TENSOR.ATTN_ROT_EMBD,
         MODEL_TENSOR.FFN_NORM,
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.ATTN_SUB_NORM,
         MODEL_TENSOR.FFN_SUB_NORM,
+        MODEL_TENSOR.ATTN_Q_SCALE,
+        MODEL_TENSOR.ATTN_K_SCALE,
+        MODEL_TENSOR.ATTN_V_SCALE,
+        MODEL_TENSOR.ATTN_OUT_SCALE,
+        MODEL_TENSOR.FFN_UP_SCALE,
+        MODEL_TENSOR.FFN_DOWN_SCALE,
+        MODEL_TENSOR.FFN_GATE_SCALE,
     ],
     # TODO
 }
diff --git a/llama.cpp b/llama.cpp
index 16ae07dd3..28854e8cf 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -498,6 +498,13 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
     LLM_TENSOR_FFN_SUB_NORM,
+    LLM_TENSOR_ATTN_Q_SCALE,
+    LLM_TENSOR_ATTN_K_SCALE,
+    LLM_TENSOR_ATTN_V_SCALE,
+    LLM_TENSOR_ATTN_OUTPUT_SCALE,
+    LLM_TENSOR_FFN_UP_SCALE,
+    LLM_TENSOR_FFN_DOWN_SCALE,
+    LLM_TENSOR_FFN_GATE_SCALE,
 };

 static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@@ -1114,19 +1121,26 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
     {
         LLM_ARCH_BITNET,
         {
-            { LLM_TENSOR_TOKEN_EMBD,    "token_embd" },
-            { LLM_TENSOR_OUTPUT_NORM,   "output_norm" },
-            { LLM_TENSOR_ATTN_Q,        "blk.%d.attn_q" },
-            { LLM_TENSOR_ATTN_K,        "blk.%d.attn_k" },
-            { LLM_TENSOR_ATTN_V,        "blk.%d.attn_v" },
-            { LLM_TENSOR_ATTN_OUT,      "blk.%d.attn_output" },
-            { LLM_TENSOR_ATTN_NORM,     "blk.%d.attn_norm" },
-            { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
-            { LLM_TENSOR_FFN_GATE,      "blk.%d.ffn_gate" },
-            { LLM_TENSOR_FFN_DOWN,      "blk.%d.ffn_down" },
-            { LLM_TENSOR_FFN_UP,        "blk.%d.ffn_up" },
-            { LLM_TENSOR_FFN_NORM,      "blk.%d.ffn_norm" },
-            { LLM_TENSOR_FFN_SUB_NORM,  "blk.%d.ffn_sub_norm" },
+            { LLM_TENSOR_TOKEN_EMBD,        "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,       "output_norm" },
+            { LLM_TENSOR_ATTN_Q,            "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,            "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,            "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,          "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_NORM,         "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_SUB_NORM,     "blk.%d.attn_sub_norm" },
+            { LLM_TENSOR_FFN_GATE,          "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,          "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,            "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_NORM,          "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_SUB_NORM,      "blk.%d.ffn_sub_norm" },
+            { LLM_TENSOR_ATTN_Q_SCALE,      "blk.%d.attn_q_scale" },
+            { LLM_TENSOR_ATTN_K_SCALE,      "blk.%d.attn_k_scale" },
+            { LLM_TENSOR_ATTN_V_SCALE,      "blk.%d.attn_v_scale" },
+            { LLM_TENSOR_ATTN_OUTPUT_SCALE, "blk.%d.attn_output_scale" },
+            { LLM_TENSOR_FFN_UP_SCALE,      "blk.%d.ffn_up_scale" },
+            { LLM_TENSOR_FFN_DOWN_SCALE,    "blk.%d.ffn_down_scale" },
+            { LLM_TENSOR_FFN_GATE_SCALE,    "blk.%d.ffn_gate_scale" },
         },
     },
     {
@@ -2075,6 +2089,15 @@ struct llama_layer {
     // long rope factors
     struct ggml_tensor * rope_long  = nullptr;
     struct ggml_tensor * rope_short = nullptr;
+
+    // bitnet scale
+    struct ggml_tensor * wq_scale;
+    struct ggml_tensor * wk_scale;
+    struct ggml_tensor * wv_scale;
+    struct ggml_tensor * wo_scale;
+    struct ggml_tensor * ffn_gate_scale;
+    struct ggml_tensor * ffn_up_scale;
+    struct ggml_tensor * ffn_down_scale;
 };

 struct llama_kv_cell {
@@ -6460,16 +6483,23 @@ static bool llm_load_tensors(
                         layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});

                         layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+                        layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_SCALE, "weight", i), {1});
                         layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+                        layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_SCALE, "weight", i), {1});
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+                        layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_SCALE, "weight", i), {1});
                         layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUTPUT_SCALE, "weight", i), {1});

                         layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                         layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});

                         layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
+                        layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SCALE, "weight", i), {1});
                         layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SCALE, "weight", i), {1});
                         layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
+                        layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SCALE, "weight", i), {1});
                     }
                 } break;
             default:
@@ -11545,6 +11575,7 @@ struct llm_build_context {
             {
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+                Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
                 cb(Qcur, "Qcur", il);
                 if (model.layers[il].bq) {
                     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
@@ -11553,6 +11584,7 @@ struct llm_build_context {

                 // B1.K
                 struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+                Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
                 cb(Kcur, "Kcur", il);
                 if (model.layers[il].bk) {
                     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
@@ -11561,6 +11593,7 @@ struct llm_build_context {

                 // B1.V
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+                Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
                 cb(Vcur, "Vcur", il);
                 if (model.layers[il].bv) {
                     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@@ -11659,6 +11692,7 @@ struct llm_build_context {
                 ggml_build_forward_expand(graph, cur_attn);

                 cur = ggml_mul_mat(ctx0, wo, cur_attn);
+                cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
                 cb(cur, "kqv_out", il);
             }

@@ -11681,10 +11715,12 @@ struct llm_build_context {
                 cb(cur, "ffn_norm", il);

                 struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
+                tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_up_scale);
                 cb(tmp, "ffn_up", il);

                 cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_gate_scale);
                 cb(cur, "ffn_gate", il);

@@ -11701,6 +11737,7 @@ struct llm_build_context {
                 cb(cur, "ffn_sub_norm", il);
                 cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
                 cb(cur, "ffn_down", il);
             }

             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -15444,6 +15481,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
     llama_ftype ftype = params->ftype;

     switch (params->ftype) {
+        case LLAMA_FTYPE_MOSTLY_Q2_2: default_type = GGML_TYPE_Q2_2; break;
         case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
         case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
         case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
@@ -15452,7 +15490,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_F16:  default_type = GGML_TYPE_F16;  break;
         case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
         case LLAMA_FTYPE_ALL_F32:     default_type = GGML_TYPE_F32;  break;
-        case LLAMA_FTYPE_MOSTLY_I2_S: default_type = GGML_TYPE_I2_S; break;

         // K-quants
         case LLAMA_FTYPE_MOSTLY_Q2_K_S:
diff --git a/llama.h b/llama.h
index f7cd33edc..7a2e0e31c 100644
--- a/llama.h
+++ b/llama.h
@@ -156,7 +156,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ1_M  = 31, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_BF16   = 32, // except 1d tensors
-        LLAMA_FTYPE_MOSTLY_I2_S   = 33,
+        LLAMA_FTYPE_MOSTLY_Q2_2   = 33, // except 1d tensors

         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };
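Taken together, each BitNet projection is now stored as a ternary weight tensor plus a one-element scale tensor, and the graph multiplies the mat-mul result by that scale (the added ggml_mul calls above). A hedged sketch of what this computes for a single output row, with bitnet_row_dot and the numbers below being purely illustrative:

    #include <stdint.h>
    #include <stdio.h>

    // y = scale * (q . x), q in {-1, 0, +1}; equivalent to using W = scale * q directly
    static float bitnet_row_dot(const int8_t * q, const float * x, int n, float scale) {
        float acc = 0.0f;
        for (int j = 0; j < n; ++j) {
            acc += (float) q[j] * x[j];   // what mul_mat over the ternary weight tensor produces
        }
        return scale * acc;               // what the added ggml_mul(..., *_scale) applies on top
    }

    int main(void) {
        const int8_t q[4] = { 1, -1, 0, 1 };
        const float  x[4] = { 0.5f, 2.0f, -1.0f, 0.25f };
        printf("%f\n", bitnet_row_dot(q, x, 4, 0.5675f));   // 0.5675 * (0.5 - 2.0 + 0.25)
        return 0;
    }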