diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 025405a2c..6bb709951 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -1397,6 +1397,49 @@ class LlamaModel(Model):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@Model.register("BitnetForCausalLM")
+class BitnetModel(Model):
+    model_arch = gguf.MODEL_ARCH.BITNET
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+        self.gguf_writer.add_rope_scaling_factor(1.0)
+
+    def weight_quant(self, weight):
+        dtype = weight.dtype
+        weight = weight.float()
+        s = 1 / weight.abs().mean().clamp(min=1e-5)
+        result = (weight * s).round().clamp(-1, 1) / s
+        return result.type(dtype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # transform weight into 1/0/-1 (in fp32)
+        if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
+                          "down_proj.weight", "up_proj.weight", "gate_proj.weight",
+                          "o_proj.weight")):
+            data_torch = self.weight_quant(data_torch)
+
+        # pad 1D tensors
+        # TODO: is padding with 0s an invariant, or do we also need some scaling factor?
+        if name.endswith(("input_layernorm.weight", "post_attention_layernorm.weight", "model.norm.weight")):
+            data_torch = torch.nn.functional.pad(data_torch, (0, 256 - data_torch.size(0) % 256), mode='constant', value=0)
+            logger.info(f"pad {name} to {data_torch.size()}")
+
+        # pad 2D tensors
+        # TODO: double-check that this is the correct way to pad the rows
+        if name.endswith(("embed_tokens.weight", "q_proj.weight", "k_proj.weight", "v_proj.weight",
+                          "down_proj.weight", "up_proj.weight", "gate_proj.weight",
+                          "o_proj.weight")):
+            data_torch = torch.nn.functional.pad(data_torch, (0, 256 - data_torch.size(1) % 256), mode='constant', value=0)
+            logger.info(f"pad {name} to {data_torch.size()}")
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @Model.register("GrokForCausalLM")
 class GrokModel(Model):
     model_arch = gguf.MODEL_ARCH.GROK
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 28584e14b..16cfd1717 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -26,6 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
     { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
     { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
+    { "I2_S", LLAMA_FTYPE_MOSTLY_I2_S, " 2 bpw per-tensor quantization", },
     { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
     { "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
diff --git a/ggml-common.h b/ggml-common.h
index e8efceb76..409fcf29e 100644
--- a/ggml-common.h
+++ b/ggml-common.h
@@ -1022,6 +1022,73 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
     0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
 GGML_TABLE_END()
 
+GGML_TABLE_BEGIN(uint32_t, i2s_i8s, 256)
+    0x00000000, 0x01000000, 0x00000000, 0xff000000,
+    0x00010000, 0x01010000, 0x00010000, 0xff010000,
+    0x00000000, 0x01000000, 0x00000000, 0xff000000,
+    0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
+    0x00000100, 0x01000100, 0x00000100, 0xff000100,
+    0x00010100, 0x01010100, 0x00010100, 0xff010100,
+    0x00000100, 0x01000100, 0x00000100, 0xff000100,
+    0x00ff0100,
0x01ff0100, 0x00ff0100, 0xffff0100, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00010000, 0x01010000, 0x00010000, 0xff010000, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, + 0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, + 0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, + 0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, + 0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, + 0x00000001, 0x01000001, 0x00000001, 0xff000001, + 0x00010001, 0x01010001, 0x00010001, 0xff010001, + 0x00000001, 0x01000001, 0x00000001, 0xff000001, + 0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, + 0x00000101, 0x01000101, 0x00000101, 0xff000101, + 0x00010101, 0x01010101, 0x00010101, 0xff010101, + 0x00000101, 0x01000101, 0x00000101, 0xff000101, + 0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101, + 0x00000001, 0x01000001, 0x00000001, 0xff000001, + 0x00010001, 0x01010001, 0x00010001, 0xff010001, + 0x00000001, 0x01000001, 0x00000001, 0xff000001, + 0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001, + 0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, + 0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01, + 0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01, + 0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00010000, 0x01010000, 0x00010000, 0xff010000, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, + 0x00000100, 0x01000100, 0x00000100, 0xff000100, + 0x00010100, 0x01010100, 0x00010100, 0xff010100, + 0x00000100, 0x01000100, 0x00000100, 0xff000100, + 0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00010000, 0x01010000, 0x00010000, 0xff010000, + 0x00000000, 0x01000000, 0x00000000, 0xff000000, + 0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000, + 0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, + 0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00, + 0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00, + 0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00, + 0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, + 0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, + 0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, + 0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, + 0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, + 0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff, + 0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff, + 0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff, + 0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, + 0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff, + 0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff, + 0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff, + 0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, + 0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff, + 0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff, + 0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff, +GGML_TABLE_END() + #define NGRID_IQ1S 2048 #define IQ1S_DELTA 0.125f #define IQ1M_DELTA 0.125f diff --git a/ggml-quants.c b/ggml-quants.c index 9f864e5c4..4b5209279 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -659,6 +659,24 @@ static inline __m128i packNibbles( __m256i bytes ) { } #endif //__loongarch_asx +void quantize_row_i8_s(const float * x, void * y, int64_t n, float* act_scales) { + int8_t* dst = (int8_t*)y; + double min = 0.00001; + double max = min; + for (int i = 0; i < n; ++i) { + max = MAX(max, (double)fabs((double)x[i])); + } + float s = 127 / max; + act_scales[0] = s; + float temp; + for (int i = 0; i < n; ++i) { + temp = round((double)(x[i] * s)); + if (temp > 127) temp = 127; + if (temp 
< -128) temp = -128;
+        dst[i] = (int8_t)(temp);
+    }
+}
+
 // reference implementation for deterministic creation of model files
 void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
     static const int qk = QK4_0;
@@ -3306,6 +3324,50 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
     return nrow * row_size;
 }
 
+size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    // 2 bits per weight
+    UNUSED(quant_weights);
+
+    size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
+
+    int n = nrow * n_per_row;
+
+    // f32 -> q8
+    double max = 0;
+    for (int i = 0; i < n; ++i) {
+        max = MAX(max, (double)fabs((double)src[i]));
+    }
+    double i2_scale = max;
+
+    uint8_t* q8 = (uint8_t*)dst;
+    for (int i=0; i<n; i++) {
+        q8[i] = [...] > 0 ? 1 : 3;
+    }
+
+    // q8 -> 0, 1, 3
+    //       |  |  |
+    //      0, 1,-1
+
+    uint8_t* i2_weight = (uint8_t*)dst;
+    for (int i=0; i
[...]
diff --git a/ggml.c b/ggml.c
[...]
             nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
         }
+        if(tensor->type == GGML_TYPE_I2_S){
+            nbytes = nbytes / 4 + 32;
+        }
     }
     else {
         nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
@@ -12271,6 +12289,10 @@ static void ggml_compute_forward_mul_mat_one_chunk(
     // 16 * 2, accounting for mmla kernels
     float tmp[32];
 
+    // for per-tensor quant
+    const float * scale = (float * )((uint8_t*) (src0->data) + (ne00 * ne01 / 4));
+    const float * act_scales = (const float*) ((const char *) wdata + (ne11 * ne10));
+
     for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
         for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
             for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
@@ -12303,7 +12325,12 @@ static void ggml_compute_forward_mul_mat_one_chunk(
                 //}
 
                 for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
-                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
+                    if (src0->type == GGML_TYPE_I2_S) {
+                        vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
+                        tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale);
+                    } else {
+                        vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ?
src1_col_stride : 0), num_rows_per_vec_dot); + } } for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) { @@ -12467,8 +12494,13 @@ UseGgmlGemm1:; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = 0; i11 < ne11; ++i11) { - from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); - wdata += row_size; + if (src0->type == GGML_TYPE_I2_S) { + float* act_scales = (float*) ((char *) wdata + (ne11 * ne10)); + quantize_row_i8_s((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4), ne10, act_scales + i11); + } else { + from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10); + wdata += row_size; + } } } } @@ -14183,6 +14215,8 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_I32: case GGML_TYPE_I64: case GGML_TYPE_F64: + case GGML_TYPE_I2_S: + case GGML_TYPE_I8_S: case GGML_TYPE_COUNT: { GGML_ASSERT(false); @@ -21325,6 +21359,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_I2_S: result = quantize_i2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t); @@ -21347,7 +21382,11 @@ size_t ggml_quantize_chunk( assert(false); } - GGML_ASSERT(result == nrows * row_size); + if (type == GGML_TYPE_I2_S) { + result = nrows * row_size / 4 + 32; + } else { + GGML_ASSERT(result == nrows * row_size); + } return result; } diff --git a/ggml.h b/ggml.h index 13502a362..c2e6859f5 100644 --- a/ggml.h +++ b/ggml.h @@ -377,6 +377,8 @@ extern "C" { GGML_TYPE_F64 = 28, GGML_TYPE_IQ1_M = 29, GGML_TYPE_BF16 = 30, + GGML_TYPE_I2_S = 31, + GGML_TYPE_I8_S = 32, GGML_TYPE_COUNT, }; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8908585cc..78c7290d2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -148,6 +148,7 @@ class MODEL_ARCH(IntEnum): OLMO = auto() ARCTIC = auto() DEEPSEEK2 = auto() + BITNET = auto() class MODEL_TENSOR(IntEnum): @@ -199,6 +200,8 @@ class MODEL_TENSOR(IntEnum): ATTN_KV_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() + FFN_SUB_NORM = auto() + ATTN_SUB_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -236,6 +239,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.OLMO: "olmo", MODEL_ARCH.ARCTIC: "arctic", MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.BITNET: "bitnet", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -287,6 +291,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", + MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", + MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -807,6 +813,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.BITNET: [ + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + 
MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ATTN_SUB_NORM,
+        MODEL_TENSOR.FFN_SUB_NORM,
+    ],
     # TODO
 }
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 81b4992a5..350035bd9 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -413,6 +413,14 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_KV_A_NORM: (
             "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
         ),
+
+        MODEL_TENSOR.ATTN_SUB_NORM: (
+            "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet
+        ),
+
+        MODEL_TENSOR.FFN_SUB_NORM: (
+            "model.layers.{bid}.mlp.ffn_layernorm", # bitnet
+        ),
     }
 
     # architecture-specific block mappings
diff --git a/llama.cpp b/llama.cpp
index 8b675ea99..085a5a236 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -221,6 +221,7 @@ enum llm_arch {
     LLM_ARCH_OLMO,
     LLM_ARCH_ARCTIC,
     LLM_ARCH_DEEPSEEK2,
+    LLM_ARCH_BITNET,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -259,6 +260,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_OLMO, "olmo" },
     { LLM_ARCH_ARCTIC, "arctic" },
     { LLM_ARCH_DEEPSEEK2, "deepseek2" },
+    { LLM_ARCH_BITNET, "bitnet" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -494,6 +496,8 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_KV_B,
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
+    LLM_TENSOR_ATTN_SUB_NORM,
+    LLM_TENSOR_FFN_SUB_NORM,
 };
 
 static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@@ -1107,6 +1111,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_BITNET,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1984,6 +2006,8 @@ struct llama_layer {
     struct ggml_tensor * attn_out_norm_b;
     struct ggml_tensor * attn_q_a_norm;
     struct ggml_tensor * attn_kv_a_norm;
+    struct ggml_tensor * attn_sub_norm;
+    struct ggml_tensor * ffn_sub_norm;
 
     // attention
     struct ggml_tensor * wq;
@@ -1997,9 +2021,9 @@ struct llama_layer {
     struct ggml_tensor * wkv_b;
 
     // attention bias
-    struct ggml_tensor * bq;
-    struct ggml_tensor * bk;
-    struct ggml_tensor * bv;
+    struct ggml_tensor * bq = nullptr;
+    struct ggml_tensor * bk = nullptr;
+    struct ggml_tensor * bv = nullptr;
     struct ggml_tensor * bo;
     struct ggml_tensor * bqkv;
 
@@ -4492,6 +4516,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_BITNET:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 26: model.type = e_model::MODEL_3B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -6405,6 +6438,44 @@ static bool llm_load_tensors(
                     }
                 }
             } break;
+        case LLM_ARCH_BITNET:
+            {
+                const uint32_t n_ff = hparams.n_ff;
+                const uint32_t n_ff_pad = GGML_PAD(n_ff, 256);
+
+                const int64_t
n_embd_pad = GGML_PAD(n_embd, 256); + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd_pad, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd_pad}); + } + + model.layers.resize(n_layer); + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd_pad}); + layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd_pad, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd_pad, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd_pad, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_pad, n_embd}); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd_pad}); + layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd_pad, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_pad, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd_pad, n_ff}); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -11449,6 +11520,223 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_bitnet() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + // B1.K + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + // B1.V + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_ext( + ctx0, 
ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
+
+                const int64_t n_ctx = cparams.n_ctx;
+                const int64_t n_head = hparams.n_head;
+                const int64_t n_head_kv = hparams.n_head_kv;
+                const int64_t n_embd_head_k = hparams.n_embd_head_k;
+                const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+                const int64_t n_embd_head_v = hparams.n_embd_head_v;
+                const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+                struct ggml_tensor * q_cur = Qcur;
+                struct ggml_tensor * kq_mask = KQ_mask;
+                float kq_scale = 1.0f/sqrtf(float(n_embd_head));
+                struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm;
+                struct ggml_cgraph * graph = gf;
+                struct ggml_tensor * wo = model.layers[il].wo;
+                struct ggml_tensor * cur_attn;
+                struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+                cb(q, "q", il);
+
+                struct ggml_tensor * k =
+                    ggml_view_3d(ctx0, kv_self.k_l[il],
+                        n_embd_head_k, n_kv, n_head_kv,
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
+                        ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
+                        0);
+                cb(k, "k", il);
+
+                if (cparams.flash_attn) {
+
+                    // split cached v into n_head heads (not transposed)
+                    struct ggml_tensor * v =
+                        ggml_view_3d(ctx0, kv_self.v_l[il],
+                            n_embd_head_v, n_kv, n_head_kv,
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                            ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
+                            0);
+                    cb(v, "v", il);
+
+                    cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+
+                    cur_attn = ggml_reshape_2d(ctx0, cur_attn, n_embd_head_v*n_head, n_tokens);
+                } else {
+                    struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+                    cb(kq, "kq", il);
+
+                    kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+                    cb(kq, "kq_soft_max_ext", il);
+
+                    GGML_ASSERT(kv_self.size == n_ctx);
+
+                    // split cached v into n_head heads
+                    struct ggml_tensor * v =
+                        ggml_view_3d(ctx0, kv_self.v_l[il],
+                            n_kv, n_embd_head_v, n_head_kv,
+                            ggml_element_size(kv_self.v_l[il])*n_ctx,
+                            ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+                            0);
+                    cb(v, "v", il);
+
+                    struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+                    cb(kqv, "kqv", il);
+
+                    struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                    cb(kqv_merged, "kqv_merged", il);
+
+                    cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+                    cb(cur_attn, "kqv_merged_cont", il);
+                }
+
+                cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
+                    attn_sub_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+                cb(cur_attn, "attn_sub_norm", il);
+
+                ggml_build_forward_expand(graph, cur_attn);
+
+                cur_attn = ggml_pad(ctx0, cur_attn, (256 - cur_attn->ne[0] % 256) % 256, 0, 0, 0);
+                cur = ggml_mul_mat(ctx0, wo, cur_attn);
+
+                cb(cur, "kqv_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            cur = ggml_pad(ctx0, cur, (256 - cur->ne[0] % 256) % 256, 0, 0, 0);
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                struct ggml_tensor * tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
+
+                cb(tmp, "ffn_up", il);
+
+                cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
+
+                cb(cur, "ffn_gate", il);
+
+
+                cur = ggml_silu(ctx0, cur);
+                cb(cur, "ffn_silu", il);
+
+                cur = ggml_mul(ctx0, cur, tmp);
+                cb(cur, "ffn_gate_par", il);
+
+                cur = llm_build_norm(ctx0, cur, hparams,
+                    model.layers[il].ffn_sub_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_sub_norm", il);
+
+                cur = ggml_pad(ctx0, cur, (256 - cur->ne[0] % 256) % 256, 0, 0, 0);
+
+                cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
+                cb(cur, "ffn_down", il);
+            }
+
+            cur = ggml_pad(ctx0, cur, (256 - cur->ne[0] % 256) % 256, 0, 0, 0);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+            model.output_norm, NULL,
+            LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+        return gf;
+    }
+
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -11671,6 +11959,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_deepseek2();
             } break;
+        case LLM_ARCH_BITNET:
+            {
+                result = llm.build_bitnet();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -15172,6 +15464,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
         case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
         case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
+        case LLAMA_FTYPE_MOSTLY_I2_S: default_type = GGML_TYPE_I2_S; break;
 
         // K-quants
         case LLAMA_FTYPE_MOSTLY_Q2_K_S:
@@ -16440,6 +16733,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_BERT:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_STABLELM:
+        case LLM_ARCH_BITNET:
         case LLM_ARCH_QWEN:
         case LLM_ARCH_QWEN2:
        case LLM_ARCH_QWEN2MOE:
diff --git a/llama.h b/llama.h
index 62908261f..f7cd33edc 100644
--- a/llama.h
+++ b/llama.h
@@ -156,6 +156,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_I2_S = 33,
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };