Compare commits
25 commits
Author | SHA1 | Date | |
---|---|---|---|
|
e9f2abfc8c | ||
|
569a03ed97 | ||
|
95dced07e4 | ||
|
7a8961fff5 | ||
|
5e5eee7b44 | ||
|
f395dd9ca0 | ||
|
c0cd08d45e | ||
|
2322e9db9a | ||
|
de1d5073e4 | ||
|
c0fd4df883 | ||
|
841c903ff9 | ||
|
abd798d70f | ||
|
65ac3a3627 | ||
|
344467f2b8 | ||
|
97d22be58c | ||
|
3a0f8b0697 | ||
|
1c5a8b7fec | ||
|
dbee0a86c1 | ||
|
ca09085593 | ||
|
4e1ab50628 | ||
|
2a01a7ce0d | ||
|
5e59660173 | ||
|
1f2e0ee012 | ||
|
57dfc3bcdf | ||
|
076b4a197b |
11 changed files with 632 additions and 7 deletions
|
@ -1397,6 +1397,49 @@ class LlamaModel(Model):
|
||||||
raise ValueError(f"Unprocessed experts: {experts}")
|
raise ValueError(f"Unprocessed experts: {experts}")
|
||||||
|
|
||||||
|
|
||||||
|
@Model.register("BitnetForCausalLM")
|
||||||
|
class BitnetModel(Model):
|
||||||
|
model_arch = gguf.MODEL_ARCH.BITNET
|
||||||
|
|
||||||
|
def set_vocab(self):
|
||||||
|
self._set_vocab_sentencepiece()
|
||||||
|
|
||||||
|
def set_gguf_parameters(self):
|
||||||
|
super().set_gguf_parameters()
|
||||||
|
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||||
|
self.gguf_writer.add_rope_scaling_factor(1.0)
|
||||||
|
|
||||||
|
def weight_quant(self, weight):
|
||||||
|
dtype = weight.dtype
|
||||||
|
weight = weight.float()
|
||||||
|
s = 1 / weight.abs().mean().clamp(min=1e-5)
|
||||||
|
result = (weight * s).round().clamp(-1, 1) / s
|
||||||
|
return result.type(dtype)
|
||||||
|
|
||||||
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
|
# transform weight into 1/0/-1 (in fp32)
|
||||||
|
if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
|
||||||
|
"down_proj.weight", "up_proj.weight", "gate_proj.weight",
|
||||||
|
"o_proj.weight")):
|
||||||
|
data_torch = self.weight_quant(data_torch)
|
||||||
|
|
||||||
|
# pad 1D tensors
|
||||||
|
# TODO: is padding with 0s an invariant, or do we also need some scaling factor?
|
||||||
|
if name.endswith(("input_layernorm.weight", "post_attention_layernorm.weight", "model.norm.weight")):
|
||||||
|
data_torch = torch.nn.functional.pad(data_torch, (0, 256 - data_torch.size(0) % 256), mode='constant', value=0)
|
||||||
|
logger.info(f"pad {name} to {data_torch.size()}")
|
||||||
|
|
||||||
|
# pad 2D tensors
|
||||||
|
# TODO: double-check that this is the correct way to pad the rows
|
||||||
|
if name.endswith(("embed_tokens.weight", "q_proj.weight", "k_proj.weight", "v_proj.weight",
|
||||||
|
"down_proj.weight", "up_proj.weight", "gate_proj.weight",
|
||||||
|
"o_proj.weight")):
|
||||||
|
data_torch = torch.nn.functional.pad(data_torch, (0, 256 - data_torch.size(1) % 256), mode='constant', value=0)
|
||||||
|
logger.info(f"pad {name} to {data_torch.size()}")
|
||||||
|
|
||||||
|
return [(self.map_tensor_name(name), data_torch)]
|
||||||
|
|
||||||
|
|
||||||
@Model.register("GrokForCausalLM")
|
@Model.register("GrokForCausalLM")
|
||||||
class GrokModel(Model):
|
class GrokModel(Model):
|
||||||
model_arch = gguf.MODEL_ARCH.GROK
|
model_arch = gguf.MODEL_ARCH.GROK
|
||||||
|
|
|
@ -26,6 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
|
||||||
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
|
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
|
||||||
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
|
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
|
||||||
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
|
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
|
||||||
|
{ "I2_S", LLAMA_FTYPE_MOSTLY_I2_S, " 2 bpw per-tensor quantization", },
|
||||||
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
|
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
|
||||||
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
|
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
|
||||||
{ "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
|
{ "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
|
||||||
|
|
|
@ -1022,6 +1022,73 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
|
||||||
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
|
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
|
||||||
GGML_TABLE_END()
|
GGML_TABLE_END()
|
||||||
|
|
||||||
|
GGML_TABLE_BEGIN(uint32_t, i2s_i8s, 256)
|
||||||
|
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||||
|
0x00010000, 0x01010000, 0x00010000, 0xff010000,
|
||||||
|
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||||
|
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
|
||||||
|
0x00000100, 0x01000100, 0x00000100, 0xff000100,
|
||||||
|
0x00010100, 0x01010100, 0x00010100, 0xff010100,
|
||||||
|
0x00000100, 0x01000100, 0x00000100, 0xff000100,
|
||||||
|
0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100,
|
||||||
|
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||||
|
0x00010000, 0x01010000, 0x00010000, 0xff010000,
|
||||||
|
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||||
|
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
|
||||||
|
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
|
||||||
|
0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00,
|
||||||
|
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
|
||||||
|
0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00,
|
||||||
|
0x00000001, 0x01000001, 0x00000001, 0xff000001,
|
||||||
|
0x00010001, 0x01010001, 0x00010001, 0xff010001,
|
||||||
|
0x00000001, 0x01000001, 0x00000001, 0xff000001,
|
||||||
|
0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001,
|
||||||
|
0x00000101, 0x01000101, 0x00000101, 0xff000101,
|
||||||
|
0x00010101, 0x01010101, 0x00010101, 0xff010101,
|
||||||
|
0x00000101, 0x01000101, 0x00000101, 0xff000101,
|
||||||
|
0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101,
|
||||||
|
0x00000001, 0x01000001, 0x00000001, 0xff000001,
|
||||||
|
0x00010001, 0x01010001, 0x00010001, 0xff010001,
|
||||||
|
0x00000001, 0x01000001, 0x00000001, 0xff000001,
|
||||||
|
0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001,
|
||||||
|
0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01,
|
||||||
|
0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01,
|
||||||
|
0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01,
|
||||||
|
0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01,
|
||||||
|
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||||
|
0x00010000, 0x01010000, 0x00010000, 0xff010000,
|
||||||
|
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||||
|
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
|
||||||
|
0x00000100, 0x01000100, 0x00000100, 0xff000100,
|
||||||
|
0x00010100, 0x01010100, 0x00010100, 0xff010100,
|
||||||
|
0x00000100, 0x01000100, 0x00000100, 0xff000100,
|
||||||
|
0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100,
|
||||||
|
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||||
|
0x00010000, 0x01010000, 0x00010000, 0xff010000,
|
||||||
|
0x00000000, 0x01000000, 0x00000000, 0xff000000,
|
||||||
|
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
|
||||||
|
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
|
||||||
|
0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00,
|
||||||
|
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
|
||||||
|
0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00,
|
||||||
|
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
|
||||||
|
0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff,
|
||||||
|
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
|
||||||
|
0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff,
|
||||||
|
0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff,
|
||||||
|
0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff,
|
||||||
|
0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff,
|
||||||
|
0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff,
|
||||||
|
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
|
||||||
|
0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff,
|
||||||
|
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
|
||||||
|
0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff,
|
||||||
|
0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff,
|
||||||
|
0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff,
|
||||||
|
0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff,
|
||||||
|
0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff,
|
||||||
|
GGML_TABLE_END()
|
||||||
|
|
||||||
#define NGRID_IQ1S 2048
|
#define NGRID_IQ1S 2048
|
||||||
#define IQ1S_DELTA 0.125f
|
#define IQ1S_DELTA 0.125f
|
||||||
#define IQ1M_DELTA 0.125f
|
#define IQ1M_DELTA 0.125f
|
||||||
|
|
143
ggml-quants.c
143
ggml-quants.c
|
@ -659,6 +659,24 @@ static inline __m128i packNibbles( __m256i bytes ) {
|
||||||
}
|
}
|
||||||
#endif //__loongarch_asx
|
#endif //__loongarch_asx
|
||||||
|
|
||||||
|
void quantize_row_i8_s(const float * x, void * y, int64_t n, float* act_scales) {
|
||||||
|
int8_t* dst = (int8_t*)y;
|
||||||
|
double min = 0.00001;
|
||||||
|
double max = min;
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
max = MAX(max, (double)fabs((double)x[i]));
|
||||||
|
}
|
||||||
|
float s = 127 / max;
|
||||||
|
act_scales[0] = s;
|
||||||
|
float temp;
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
temp = round((double)(x[i] * s));
|
||||||
|
if (temp > 127) temp = 127;
|
||||||
|
if (temp < -128) temp = -128;
|
||||||
|
dst[i] = (int8_t)(temp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// reference implementation for deterministic creation of model files
|
// reference implementation for deterministic creation of model files
|
||||||
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
|
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
|
||||||
static const int qk = QK4_0;
|
static const int qk = QK4_0;
|
||||||
|
@ -3306,6 +3324,50 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
|
||||||
return nrow * row_size;
|
return nrow * row_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||||
|
// 2 bits per weight
|
||||||
|
UNUSED(quant_weights);
|
||||||
|
|
||||||
|
size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
|
||||||
|
|
||||||
|
int n = nrow * n_per_row;
|
||||||
|
|
||||||
|
// f32 -> q8
|
||||||
|
double max = 0;
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
max = MAX(max, (double)fabs((double)src[i]));
|
||||||
|
}
|
||||||
|
double i2_scale = max;
|
||||||
|
|
||||||
|
uint8_t* q8 = (uint8_t*)dst;
|
||||||
|
for (int i=0; i<n; i++) {
|
||||||
|
if (fabs((double)(src[i])) < 1e-6) {
|
||||||
|
q8[i] = 0;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
q8[i] = (double)src[i] * i2_scale > 0 ? 1 : 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// q8 -> 0, 1, 3
|
||||||
|
// | | |
|
||||||
|
// 0, 1,-1
|
||||||
|
|
||||||
|
uint8_t* i2_weight = (uint8_t*)dst;
|
||||||
|
for (int i=0; i<n; i++) {
|
||||||
|
int group_idx = i / 4;
|
||||||
|
int group_pos = i % 4;
|
||||||
|
uint8_t temp = (q8[i] << (6 - 2 * group_pos));
|
||||||
|
q8[i] = 0;
|
||||||
|
i2_weight[group_idx] |= temp;
|
||||||
|
}
|
||||||
|
|
||||||
|
float* scale_ptr = (float*)((char*)i2_weight + n / 4);
|
||||||
|
scale_ptr[0] = i2_scale;
|
||||||
|
|
||||||
|
// 32B for alignment
|
||||||
|
return nrow * row_size / 4 + 32;
|
||||||
|
}
|
||||||
|
|
||||||
// ====================== "True" 2-bit (de)-quantization
|
// ====================== "True" 2-bit (de)-quantization
|
||||||
|
|
||||||
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
|
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
|
||||||
|
@ -3726,6 +3788,86 @@ static inline __m128i get_scale_shuffle(int i) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
//====================================== I2 ===============================================
|
||||||
|
|
||||||
|
void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||||
|
const uint8_t * restrict x = vx;
|
||||||
|
const int8_t * restrict y = vy;
|
||||||
|
|
||||||
|
UNUSED(bs);
|
||||||
|
UNUSED(bx);
|
||||||
|
UNUSED(by);
|
||||||
|
UNUSED(nrc);
|
||||||
|
|
||||||
|
#if defined(__AVX2__)
|
||||||
|
__m256i accu = _mm256_setzero_si256();
|
||||||
|
|
||||||
|
// max group_size is 128 (2^8)
|
||||||
|
// limited by 8640 to 2 (8640 % (2 * 32) == 0)
|
||||||
|
int group_num = 2;
|
||||||
|
|
||||||
|
for (int i=0; i < n / (group_num * 32); i++){
|
||||||
|
__m256i laccu = _mm256_setzero_si256();
|
||||||
|
__m256i haccu = _mm256_setzero_si256();
|
||||||
|
|
||||||
|
for (int j=0; j < group_num; j++) {
|
||||||
|
__m256i xq8 = _mm256_set_epi32(
|
||||||
|
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 7]],
|
||||||
|
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 6]],
|
||||||
|
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 5]],
|
||||||
|
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 4]],
|
||||||
|
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 3]],
|
||||||
|
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 2]],
|
||||||
|
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 1]],
|
||||||
|
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 0]]
|
||||||
|
);
|
||||||
|
|
||||||
|
__m256i yq8 = _mm256_loadu_si256((const __m256i*)(y + i * group_num * 32 + j * 32));
|
||||||
|
|
||||||
|
__m128i hxq8 = _mm256_castsi256_si128(xq8);
|
||||||
|
__m128i lxq8 = _mm256_extractf128_si256(xq8, 1);
|
||||||
|
__m128i hyq8 = _mm256_castsi256_si128(yq8);
|
||||||
|
__m128i lyq8 = _mm256_extractf128_si256(yq8, 1);
|
||||||
|
|
||||||
|
__m256i hxq16 = _mm256_cvtepi8_epi16(hxq8);
|
||||||
|
__m256i lxq16 = _mm256_cvtepi8_epi16(lxq8);
|
||||||
|
__m256i hyq16 = _mm256_cvtepi8_epi16(hyq8);
|
||||||
|
__m256i lyq16 = _mm256_cvtepi8_epi16(lyq8);
|
||||||
|
|
||||||
|
__m256i hzq16 = _mm256_sign_epi16(hyq16, hxq16);
|
||||||
|
__m256i lzq16 = _mm256_sign_epi16(lyq16, lxq16);
|
||||||
|
|
||||||
|
haccu = _mm256_add_epi16(haccu, hzq16);
|
||||||
|
laccu = _mm256_add_epi16(laccu, lzq16);
|
||||||
|
}
|
||||||
|
|
||||||
|
__m256i hhzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(haccu));
|
||||||
|
__m256i hlzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(haccu, 1));
|
||||||
|
__m256i llzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(laccu));
|
||||||
|
__m256i lhzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(laccu, 1));
|
||||||
|
|
||||||
|
accu = _mm256_add_epi32(accu, hhzq32);
|
||||||
|
accu = _mm256_add_epi32(accu, hlzq32);
|
||||||
|
accu = _mm256_add_epi32(accu, llzq32);
|
||||||
|
accu = _mm256_add_epi32(accu, lhzq32);
|
||||||
|
}
|
||||||
|
int sumi = hsum_i32_8(accu);
|
||||||
|
*s = (float)sumi;
|
||||||
|
#else
|
||||||
|
|
||||||
|
int sumi = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < n / 4; i++) {
|
||||||
|
const int8_t* weight = (const int8_t *)(i2s_i8s + x[i]);
|
||||||
|
sumi += (int)y[i*4+0] * weight[0];
|
||||||
|
sumi += (int)y[i*4+1] * weight[1];
|
||||||
|
sumi += (int)y[i*4+2] * weight[2];
|
||||||
|
sumi += (int)y[i*4+3] * weight[3];
|
||||||
|
}
|
||||||
|
*s = (float)sumi;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||||
const int qk = QK8_0;
|
const int qk = QK8_0;
|
||||||
const int nb = n / qk;
|
const int nb = n / qk;
|
||||||
|
@ -14367,6 +14509,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
||||||
case GGML_TYPE_I16:
|
case GGML_TYPE_I16:
|
||||||
case GGML_TYPE_I32:
|
case GGML_TYPE_I32:
|
||||||
case GGML_TYPE_I64:
|
case GGML_TYPE_I64:
|
||||||
|
case GGML_TYPE_I2_S:
|
||||||
// nothing to validate
|
// nothing to validate
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -51,6 +51,7 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
|
||||||
void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||||
void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||||
void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||||
|
void quantize_row_i8_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k, float* n);
|
||||||
|
|
||||||
// Dequantization
|
// Dequantization
|
||||||
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||||
|
@ -99,6 +100,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
|
||||||
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||||
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||||
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||||
|
void ggml_vec_dot_i2_i8_s (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||||
|
|
||||||
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
|
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
|
||||||
size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||||
|
@ -121,6 +123,7 @@ size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
|
||||||
size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||||
size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||||
size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||||
|
size_t quantize_i2_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||||
|
|
||||||
void iq2xs_init_impl(enum ggml_type type);
|
void iq2xs_init_impl(enum ggml_type type);
|
||||||
void iq2xs_free_impl(enum ggml_type type);
|
void iq2xs_free_impl(enum ggml_type type);
|
||||||
|
|
47
ggml.c
47
ggml.c
|
@ -908,6 +908,21 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||||
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
|
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
|
||||||
.vec_dot_type = GGML_TYPE_BF16,
|
.vec_dot_type = GGML_TYPE_BF16,
|
||||||
.nrows = 1,
|
.nrows = 1,
|
||||||
|
},
|
||||||
|
[GGML_TYPE_I2_S] = {
|
||||||
|
.type_name = "i2_s",
|
||||||
|
.blck_size = 1,
|
||||||
|
.type_size = sizeof(int8_t),
|
||||||
|
.is_quantized = true,
|
||||||
|
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_i2_i8_s,
|
||||||
|
.vec_dot_type = GGML_TYPE_I8_S,
|
||||||
|
.nrows = 1,
|
||||||
|
},
|
||||||
|
[GGML_TYPE_I8_S] = {
|
||||||
|
.type_name = "i8_s",
|
||||||
|
.blck_size = 1,
|
||||||
|
.type_size = sizeof(int8_t),
|
||||||
|
.is_quantized = true,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -3056,6 +3071,9 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) {
|
||||||
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
|
||||||
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
|
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
|
||||||
}
|
}
|
||||||
|
if(tensor->type == GGML_TYPE_I2_S){
|
||||||
|
nbytes = nbytes / 4 + 32;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
|
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
|
||||||
|
@ -12271,6 +12289,10 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
||||||
// 16 * 2, accounting for mmla kernels
|
// 16 * 2, accounting for mmla kernels
|
||||||
float tmp[32];
|
float tmp[32];
|
||||||
|
|
||||||
|
// for per-tensor quant
|
||||||
|
const float * scale = (float * )((uint8_t*) (src0->data) + (ne00 * ne01 / 4));
|
||||||
|
const float * act_scales = (const float*) ((const char *) wdata + (ne11 * ne10));
|
||||||
|
|
||||||
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
||||||
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
||||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
|
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
|
||||||
|
@ -12303,7 +12325,12 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
||||||
//}
|
//}
|
||||||
|
|
||||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
|
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
|
||||||
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
|
if (src0->type == GGML_TYPE_I2_S) {
|
||||||
|
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
|
||||||
|
tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale);
|
||||||
|
} else {
|
||||||
|
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
|
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
|
||||||
|
@ -12467,8 +12494,13 @@ UseGgmlGemm1:;
|
||||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||||
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
|
if (src0->type == GGML_TYPE_I2_S) {
|
||||||
wdata += row_size;
|
float* act_scales = (float*) ((char *) wdata + (ne11 * ne10));
|
||||||
|
quantize_row_i8_s((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4), ne10, act_scales + i11);
|
||||||
|
} else {
|
||||||
|
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
|
||||||
|
wdata += row_size;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14183,6 +14215,8 @@ static void ggml_compute_forward_clamp(
|
||||||
case GGML_TYPE_I32:
|
case GGML_TYPE_I32:
|
||||||
case GGML_TYPE_I64:
|
case GGML_TYPE_I64:
|
||||||
case GGML_TYPE_F64:
|
case GGML_TYPE_F64:
|
||||||
|
case GGML_TYPE_I2_S:
|
||||||
|
case GGML_TYPE_I8_S:
|
||||||
case GGML_TYPE_COUNT:
|
case GGML_TYPE_COUNT:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
@ -21325,6 +21359,7 @@ size_t ggml_quantize_chunk(
|
||||||
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||||
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||||
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||||
|
case GGML_TYPE_I2_S: result = quantize_i2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
{
|
{
|
||||||
size_t elemsize = sizeof(ggml_fp16_t);
|
size_t elemsize = sizeof(ggml_fp16_t);
|
||||||
|
@ -21347,7 +21382,11 @@ size_t ggml_quantize_chunk(
|
||||||
assert(false);
|
assert(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_ASSERT(result == nrows * row_size);
|
if (type == GGML_TYPE_I2_S) {
|
||||||
|
result = nrows * row_size / 4 + 32;
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(result == nrows * row_size);
|
||||||
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
2
ggml.h
2
ggml.h
|
@ -377,6 +377,8 @@ extern "C" {
|
||||||
GGML_TYPE_F64 = 28,
|
GGML_TYPE_F64 = 28,
|
||||||
GGML_TYPE_IQ1_M = 29,
|
GGML_TYPE_IQ1_M = 29,
|
||||||
GGML_TYPE_BF16 = 30,
|
GGML_TYPE_BF16 = 30,
|
||||||
|
GGML_TYPE_I2_S = 31,
|
||||||
|
GGML_TYPE_I8_S = 32,
|
||||||
GGML_TYPE_COUNT,
|
GGML_TYPE_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -148,6 +148,7 @@ class MODEL_ARCH(IntEnum):
|
||||||
OLMO = auto()
|
OLMO = auto()
|
||||||
ARCTIC = auto()
|
ARCTIC = auto()
|
||||||
DEEPSEEK2 = auto()
|
DEEPSEEK2 = auto()
|
||||||
|
BITNET = auto()
|
||||||
|
|
||||||
|
|
||||||
class MODEL_TENSOR(IntEnum):
|
class MODEL_TENSOR(IntEnum):
|
||||||
|
@ -199,6 +200,8 @@ class MODEL_TENSOR(IntEnum):
|
||||||
ATTN_KV_B = auto()
|
ATTN_KV_B = auto()
|
||||||
ATTN_Q_A_NORM = auto()
|
ATTN_Q_A_NORM = auto()
|
||||||
ATTN_KV_A_NORM = auto()
|
ATTN_KV_A_NORM = auto()
|
||||||
|
FFN_SUB_NORM = auto()
|
||||||
|
ATTN_SUB_NORM = auto()
|
||||||
|
|
||||||
|
|
||||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||||
|
@ -236,6 +239,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||||
MODEL_ARCH.OLMO: "olmo",
|
MODEL_ARCH.OLMO: "olmo",
|
||||||
MODEL_ARCH.ARCTIC: "arctic",
|
MODEL_ARCH.ARCTIC: "arctic",
|
||||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||||
|
MODEL_ARCH.BITNET: "bitnet",
|
||||||
}
|
}
|
||||||
|
|
||||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
|
@ -287,6 +291,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||||
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
|
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
|
||||||
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
|
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
|
||||||
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
|
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
|
||||||
|
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
|
||||||
|
MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
|
||||||
}
|
}
|
||||||
|
|
||||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
|
@ -807,6 +813,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||||
],
|
],
|
||||||
|
MODEL_ARCH.BITNET: [
|
||||||
|
MODEL_TENSOR.ATTN_Q,
|
||||||
|
MODEL_TENSOR.ATTN_K,
|
||||||
|
MODEL_TENSOR.ATTN_V,
|
||||||
|
MODEL_TENSOR.TOKEN_EMBD,
|
||||||
|
MODEL_TENSOR.OUTPUT_NORM,
|
||||||
|
MODEL_TENSOR.ROPE_FREQS,
|
||||||
|
MODEL_TENSOR.ATTN_NORM,
|
||||||
|
MODEL_TENSOR.ATTN_QKV,
|
||||||
|
MODEL_TENSOR.ATTN_OUT,
|
||||||
|
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||||
|
MODEL_TENSOR.FFN_NORM,
|
||||||
|
MODEL_TENSOR.FFN_GATE,
|
||||||
|
MODEL_TENSOR.FFN_DOWN,
|
||||||
|
MODEL_TENSOR.FFN_UP,
|
||||||
|
MODEL_TENSOR.ATTN_SUB_NORM,
|
||||||
|
MODEL_TENSOR.FFN_SUB_NORM,
|
||||||
|
],
|
||||||
# TODO
|
# TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -413,6 +413,14 @@ class TensorNameMap:
|
||||||
MODEL_TENSOR.ATTN_KV_A_NORM: (
|
MODEL_TENSOR.ATTN_KV_A_NORM: (
|
||||||
"model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
|
"model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
|
||||||
),
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.ATTN_SUB_NORM: (
|
||||||
|
"model.layers.{bid}.self_attn.inner_attn_ln", # bitnet
|
||||||
|
),
|
||||||
|
|
||||||
|
MODEL_TENSOR.FFN_SUB_NORM: (
|
||||||
|
"model.layers.{bid}.mlp.ffn_layernorm", # bitnet
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
# architecture-specific block mappings
|
# architecture-specific block mappings
|
||||||
|
|
300
llama.cpp
300
llama.cpp
|
@ -221,6 +221,7 @@ enum llm_arch {
|
||||||
LLM_ARCH_OLMO,
|
LLM_ARCH_OLMO,
|
||||||
LLM_ARCH_ARCTIC,
|
LLM_ARCH_ARCTIC,
|
||||||
LLM_ARCH_DEEPSEEK2,
|
LLM_ARCH_DEEPSEEK2,
|
||||||
|
LLM_ARCH_BITNET,
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -259,6 +260,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||||
{ LLM_ARCH_OLMO, "olmo" },
|
{ LLM_ARCH_OLMO, "olmo" },
|
||||||
{ LLM_ARCH_ARCTIC, "arctic" },
|
{ LLM_ARCH_ARCTIC, "arctic" },
|
||||||
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
||||||
|
{ LLM_ARCH_BITNET, "bitnet" },
|
||||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -494,6 +496,8 @@ enum llm_tensor {
|
||||||
LLM_TENSOR_ATTN_KV_B,
|
LLM_TENSOR_ATTN_KV_B,
|
||||||
LLM_TENSOR_ATTN_Q_A_NORM,
|
LLM_TENSOR_ATTN_Q_A_NORM,
|
||||||
LLM_TENSOR_ATTN_KV_A_NORM,
|
LLM_TENSOR_ATTN_KV_A_NORM,
|
||||||
|
LLM_TENSOR_ATTN_SUB_NORM,
|
||||||
|
LLM_TENSOR_FFN_SUB_NORM,
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
|
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
|
||||||
|
@ -1107,6 +1111,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
LLM_ARCH_BITNET,
|
||||||
|
{
|
||||||
|
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||||
|
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||||
|
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||||
|
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||||
|
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||||
|
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||||
|
{ LLM_TENSOR_ATTN_SUB_NORM, "blk.%d.attn_sub_norm" },
|
||||||
|
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||||
|
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||||
|
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||||
|
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||||
|
{ LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
LLM_ARCH_UNKNOWN,
|
LLM_ARCH_UNKNOWN,
|
||||||
{
|
{
|
||||||
|
@ -1984,6 +2006,8 @@ struct llama_layer {
|
||||||
struct ggml_tensor * attn_out_norm_b;
|
struct ggml_tensor * attn_out_norm_b;
|
||||||
struct ggml_tensor * attn_q_a_norm;
|
struct ggml_tensor * attn_q_a_norm;
|
||||||
struct ggml_tensor * attn_kv_a_norm;
|
struct ggml_tensor * attn_kv_a_norm;
|
||||||
|
struct ggml_tensor * attn_sub_norm;
|
||||||
|
struct ggml_tensor * ffn_sub_norm;
|
||||||
|
|
||||||
// attention
|
// attention
|
||||||
struct ggml_tensor * wq;
|
struct ggml_tensor * wq;
|
||||||
|
@ -1997,9 +2021,9 @@ struct llama_layer {
|
||||||
struct ggml_tensor * wkv_b;
|
struct ggml_tensor * wkv_b;
|
||||||
|
|
||||||
// attention bias
|
// attention bias
|
||||||
struct ggml_tensor * bq;
|
struct ggml_tensor * bq = nullptr;
|
||||||
struct ggml_tensor * bk;
|
struct ggml_tensor * bk = nullptr;
|
||||||
struct ggml_tensor * bv;
|
struct ggml_tensor * bv = nullptr;
|
||||||
struct ggml_tensor * bo;
|
struct ggml_tensor * bo;
|
||||||
struct ggml_tensor * bqkv;
|
struct ggml_tensor * bqkv;
|
||||||
|
|
||||||
|
@ -4492,6 +4516,15 @@ static void llm_load_hparams(
|
||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_BITNET:
|
||||||
|
{
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
|
|
||||||
|
switch (hparams.n_layer) {
|
||||||
|
case 26: model.type = e_model::MODEL_3B; break;
|
||||||
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
|
}
|
||||||
|
} break;
|
||||||
default: (void)0;
|
default: (void)0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6405,6 +6438,44 @@ static bool llm_load_tensors(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_BITNET:
|
||||||
|
{
|
||||||
|
const uint32_t n_ff = hparams.n_ff;
|
||||||
|
const uint32_t n_ff_pad = GGML_PAD(n_ff, 256);
|
||||||
|
|
||||||
|
const int64_t n_embd_pad = GGML_PAD(n_embd, 256);
|
||||||
|
|
||||||
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd_pad, n_vocab});
|
||||||
|
|
||||||
|
// output
|
||||||
|
{
|
||||||
|
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd_pad});
|
||||||
|
}
|
||||||
|
|
||||||
|
model.layers.resize(n_layer);
|
||||||
|
|
||||||
|
for (int i = 0; i < n_layer; ++i) {
|
||||||
|
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||||
|
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||||
|
|
||||||
|
auto & layer = model.layers[i];
|
||||||
|
|
||||||
|
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd_pad});
|
||||||
|
layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
|
||||||
|
|
||||||
|
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd_pad, n_embd});
|
||||||
|
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd_pad, n_embd_gqa});
|
||||||
|
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd_pad, n_embd_gqa});
|
||||||
|
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_pad, n_embd});
|
||||||
|
|
||||||
|
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd_pad});
|
||||||
|
layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
|
||||||
|
|
||||||
|
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd_pad, n_ff});
|
||||||
|
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_pad, n_embd});
|
||||||
|
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd_pad, n_ff});
|
||||||
|
}
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error("unknown architecture");
|
throw std::runtime_error("unknown architecture");
|
||||||
}
|
}
|
||||||
|
@ -11449,6 +11520,223 @@ struct llm_build_context {
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_cgraph * build_bitnet() {
|
||||||
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||||
|
|
||||||
|
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||||
|
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||||
|
|
||||||
|
struct ggml_tensor * cur;
|
||||||
|
struct ggml_tensor * inpL;
|
||||||
|
|
||||||
|
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||||
|
|
||||||
|
// inp_pos - contains the positions
|
||||||
|
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||||
|
|
||||||
|
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||||
|
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||||
|
|
||||||
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
struct ggml_tensor * inpSA = inpL;
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
model.layers[il].attn_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "attn_norm", il);
|
||||||
|
|
||||||
|
// self-attention
|
||||||
|
{
|
||||||
|
// compute Q and K and RoPE them
|
||||||
|
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
if (model.layers[il].bq) {
|
||||||
|
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
// B1.K
|
||||||
|
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
if (model.layers[il].bk) {
|
||||||
|
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
// B1.V
|
||||||
|
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
if (model.layers[il].bv) {
|
||||||
|
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||||
|
cb(Vcur, "Vcur", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
Qcur = ggml_rope_ext(
|
||||||
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||||
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
Kcur = ggml_rope_ext(
|
||||||
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
|
ext_factor, attn_factor, beta_fast, beta_slow
|
||||||
|
);
|
||||||
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
llm_build_kv_store(ctx0, hparams, cparams, kv_self, gf, Kcur, Vcur, n_tokens, kv_head, cb, il);
|
||||||
|
|
||||||
|
const int64_t n_ctx = cparams.n_ctx;
|
||||||
|
const int64_t n_head = hparams.n_head;
|
||||||
|
const int64_t n_head_kv = hparams.n_head_kv;
|
||||||
|
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||||
|
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||||
|
const int64_t n_embd_head_v = hparams.n_embd_head_v;
|
||||||
|
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
|
|
||||||
|
struct ggml_tensor * q_cur = Qcur;
|
||||||
|
struct ggml_tensor * kq_mask = KQ_mask;
|
||||||
|
float kq_scale = 1.0f/sqrtf(float(n_embd_head));
|
||||||
|
struct ggml_tensor * attn_sub_norm = model.layers[il].attn_sub_norm;
|
||||||
|
struct ggml_cgraph * graph = gf;
|
||||||
|
struct ggml_tensor * wo = model.layers[il].wo;
|
||||||
|
struct ggml_tensor * cur_attn;
|
||||||
|
struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
|
||||||
|
cb(q, "q", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * k =
|
||||||
|
ggml_view_3d(ctx0, kv_self.k_l[il],
|
||||||
|
n_embd_head_k, n_kv, n_head_kv,
|
||||||
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
|
||||||
|
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
|
||||||
|
0);
|
||||||
|
cb(k, "k", il);
|
||||||
|
|
||||||
|
if (cparams.flash_attn) {
|
||||||
|
|
||||||
|
// split cached v into n_head heads (not transposed)
|
||||||
|
struct ggml_tensor * v =
|
||||||
|
ggml_view_3d(ctx0, kv_self.v_l[il],
|
||||||
|
n_embd_head_v, n_kv, n_head_kv,
|
||||||
|
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
|
||||||
|
ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
|
||||||
|
0);
|
||||||
|
cb(v, "v", il);
|
||||||
|
|
||||||
|
cur_attn = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||||
|
|
||||||
|
cur_attn = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
|
||||||
|
} else {
|
||||||
|
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||||
|
cb(kq, "kq", il);
|
||||||
|
|
||||||
|
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||||
|
cb(kq, "kq_soft_max_ext", il);
|
||||||
|
|
||||||
|
GGML_ASSERT(kv_self.size == n_ctx);
|
||||||
|
|
||||||
|
// split cached v into n_head heads
|
||||||
|
struct ggml_tensor * v =
|
||||||
|
ggml_view_3d(ctx0, kv_self.v_l[il],
|
||||||
|
n_kv, n_embd_head_v, n_head_kv,
|
||||||
|
ggml_element_size(kv_self.v_l[il])*n_ctx,
|
||||||
|
ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
|
||||||
|
0);
|
||||||
|
cb(v, "v", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
|
||||||
|
cb(kqv, "kqv", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||||
|
cb(kqv_merged, "kqv_merged", il);
|
||||||
|
|
||||||
|
cur_attn = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
|
||||||
|
cb(cur_attn, "kqv_merged_cont", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur_attn = llm_build_norm(ctx0, cur_attn, hparams,
|
||||||
|
attn_sub_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur_attn, "attn_sub_norm", il);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(graph, cur_attn);
|
||||||
|
|
||||||
|
cur_attn = ggml_pad(ctx0, cur_attn, (256 - cur_attn->ne[0] % 256) % 256, 0, 0, 0);
|
||||||
|
cur = ggml_mul_mat(ctx0, wo, cur_attn);
|
||||||
|
|
||||||
|
cb(cur, "kqv_out", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (il == n_layer - 1) {
|
||||||
|
// skip computing output for unused tokens
|
||||||
|
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||||
|
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||||
|
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = ggml_pad(ctx0, cur, (256 - cur->ne[0] % 256) % 256, 0, 0, 0);
|
||||||
|
|
||||||
|
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||||
|
cb(ffn_inp, "ffn_inp", il);
|
||||||
|
|
||||||
|
// feed-forward forward
|
||||||
|
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||||
|
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||||
|
model.layers[il].ffn_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "ffn_norm", il);
|
||||||
|
|
||||||
|
struct ggml_tensor * tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
|
||||||
|
|
||||||
|
cb(tmp, "ffn_up", il);
|
||||||
|
|
||||||
|
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
|
||||||
|
|
||||||
|
cb(cur, "ffn_gate", il);
|
||||||
|
|
||||||
|
|
||||||
|
cur = ggml_silu(ctx0, cur);
|
||||||
|
cb(cur, "ffn_silu", il);
|
||||||
|
|
||||||
|
cur = ggml_mul(ctx0, cur, tmp);
|
||||||
|
cb(cur, "ffn_gate_par", il);
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, cur, hparams,
|
||||||
|
model.layers[il].ffn_sub_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, il);
|
||||||
|
cb(cur, "ffn_sub_norm", il);
|
||||||
|
|
||||||
|
cur = ggml_pad(ctx0, cur, (256 - cur->ne[0] % 256) % 256, 0, 0, 0);
|
||||||
|
|
||||||
|
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
|
||||||
|
cb(cur, "ffn_down", il);
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = ggml_pad(ctx0, cur, (256 - cur->ne[0] % 256) % 256, 0, 0, 0);
|
||||||
|
|
||||||
|
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||||
|
cb(cur, "l_out", il);
|
||||||
|
|
||||||
|
// input for next layer
|
||||||
|
inpL = cur;
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = inpL;
|
||||||
|
|
||||||
|
cur = llm_build_norm(ctx0, cur, hparams,
|
||||||
|
model.output_norm, NULL,
|
||||||
|
LLM_NORM_RMS, cb, -1);
|
||||||
|
cb(cur, "result_norm", -1);
|
||||||
|
|
||||||
|
// lm_head
|
||||||
|
cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
|
||||||
|
cb(cur, "result_output", -1);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
return gf;
|
||||||
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||||
|
@ -11671,6 +11959,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
{
|
{
|
||||||
result = llm.build_deepseek2();
|
result = llm.build_deepseek2();
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_ARCH_BITNET:
|
||||||
|
{
|
||||||
|
result = llm.build_bitnet();
|
||||||
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
@ -15172,6 +15464,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
|
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
|
||||||
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
|
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
|
||||||
case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
|
case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
|
||||||
|
case LLAMA_FTYPE_MOSTLY_I2_S: default_type = GGML_TYPE_I2_S; break;
|
||||||
|
|
||||||
// K-quants
|
// K-quants
|
||||||
case LLAMA_FTYPE_MOSTLY_Q2_K_S:
|
case LLAMA_FTYPE_MOSTLY_Q2_K_S:
|
||||||
|
@ -16440,6 +16733,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||||
case LLM_ARCH_BERT:
|
case LLM_ARCH_BERT:
|
||||||
case LLM_ARCH_NOMIC_BERT:
|
case LLM_ARCH_NOMIC_BERT:
|
||||||
case LLM_ARCH_STABLELM:
|
case LLM_ARCH_STABLELM:
|
||||||
|
case LLM_ARCH_BITNET:
|
||||||
case LLM_ARCH_QWEN:
|
case LLM_ARCH_QWEN:
|
||||||
case LLM_ARCH_QWEN2:
|
case LLM_ARCH_QWEN2:
|
||||||
case LLM_ARCH_QWEN2MOE:
|
case LLM_ARCH_QWEN2MOE:
|
||||||
|
|
1
llama.h
1
llama.h
|
@ -156,6 +156,7 @@ extern "C" {
|
||||||
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
|
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
|
||||||
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
|
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
|
||||||
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
|
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_I2_S = 33,
|
||||||
|
|
||||||
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||||
};
|
};
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue