This commit is contained in:
Eddie-Wang1120 2024-06-17 20:33:09 +08:00
parent 569a03ed97
commit a03eff318c
10 changed files with 220 additions and 193 deletions

View file

@ -1413,17 +1413,47 @@ class BitnetModel(Model):
dtype = weight.dtype
weight = weight.float()
s = 1 / weight.abs().mean().clamp(min=1e-5)
result = (weight * s).round().clamp(-1, 1) / s
return result.type(dtype)
weight = (weight * s).round().clamp(-1, 1) / s
scale = weight.abs().max().unsqueeze(0)
weight = torch.where(weight.abs().less(1e-6), 0, weight).type(dtype)
weight = torch.sign(weight).type(dtype)
return weight.type(dtype), scale.type(torch.float32)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# transform weight into 1/0/-1 (in fp32)
if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
"down_proj.weight", "up_proj.weight", "gate_proj.weight",
"o_proj.weight")):
data_torch = self.weight_quant(data_torch)
weight_torch, scale_torch = self.weight_quant(data_torch)
return [(self.map_tensor_name(name), data_torch)]
tensors: list[tuple[str, Tensor]] = []
if name.endswith("q_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q_SCALE, bid), scale_torch))
elif name.endswith("k_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K_SCALE, bid), scale_torch))
elif name.endswith("v_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V_SCALE, bid), scale_torch))
elif name.endswith("o_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT_SCALE, bid), scale_torch))
elif name.endswith("up_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SCALE, bid), scale_torch))
elif name.endswith("down_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN_SCALE, bid), scale_torch))
elif name.endswith("gate_proj.weight"):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), weight_torch))
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SCALE, bid), scale_torch))
if len(tensors) == 0:
tensors.append((self.map_tensor_name(name), data_torch))
return tensors
@Model.register("GrokForCausalLM")

View file

@ -26,7 +26,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
{ "I2_S", LLAMA_FTYPE_MOSTLY_I2_S, " 2 bpw per-tensor quantization", },
{ "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2 bpw quantization", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
{ "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },

View file

@ -137,6 +137,13 @@ typedef sycl::half2 ggml_half2;
#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
#define QK2_2 32
typedef struct {
ggml_half d; // delta
uint8_t qs[QK2_2 / 4]; // nibbles / quants
} block_q2_2;
static_assert(sizeof(block_q2_2) == sizeof(ggml_half) + QK2_2 / 4, "wrong q4_0 block size/padding");
#define QK4_0 32
typedef struct {
ggml_half d; // delta
@ -1022,7 +1029,7 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
GGML_TABLE_END()
GGML_TABLE_BEGIN(uint32_t, i2s_i8s, 256)
GGML_TABLE_BEGIN(uint32_t, q22_grid, 256)
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00010000, 0x01010000, 0x00010000, 0xff010000,
0x00000000, 0x01000000, 0x00000000, 0xff000000,

View file

@ -659,25 +659,44 @@ static inline __m128i packNibbles( __m256i bytes ) {
}
#endif //__loongarch_asx
void quantize_row_i8_s(const float * x, void * y, int64_t n, float* act_scales) {
int8_t* dst = (int8_t*)y;
double min = 0.00001;
double max = min;
for (int i = 0; i < n; ++i) {
max = MAX(max, (double)fabs((double)x[i]));
void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict y, int64_t k) {
static const int qk = QK2_2;
assert(k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) {
const float d = 1.0f;
y[i].d = GGML_FP32_TO_FP16(d);
for (int j = 0; j < qk/4; ++j) {
int8_t x0 = (int8_t)x[i*qk + j*4 + 0];
int8_t x1 = (int8_t)x[i*qk + j*4 + 1];
int8_t x2 = (int8_t)x[i*qk + j*4 + 2];
int8_t x3 = (int8_t)x[i*qk + j*4 + 3];
const uint8_t xi0 = x0 >= 0 ? x0 : 3;
const uint8_t xi1 = x1 >= 0 ? x1 : 3;
const uint8_t xi2 = x2 >= 0 ? x2 : 3;
const uint8_t xi3 = x3 >= 0 ? x3 : 3;
y[i].qs[j] = 0;
y[i].qs[j] |= (xi0 << 6);
y[i].qs[j] |= (xi1 << 4);
y[i].qs[j] |= (xi2 << 2);
y[i].qs[j] |= (xi3 << 0);
}
float s = 127 / max;
act_scales[0] = s;
float temp;
for (int i = 0; i < n; ++i) {
temp = round((double)(x[i] * s));
if (temp > 127) temp = 127;
if (temp < -128) temp = -128;
dst[i] = (int8_t)(temp);
}
}
// reference implementation for deterministic creation of model files
void quantize_row_q2_2(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q2_2_reference(x, y, k);
}
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
static const int qk = QK4_0;
@ -3324,48 +3343,11 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
return nrow * row_size;
}
size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
// 2 bits per weight
UNUSED(quant_weights);
size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
int n = nrow * n_per_row;
// f32 -> q8
double max = 0;
for (int i = 0; i < n; ++i) {
max = MAX(max, (double)fabs((double)src[i]));
}
double i2_scale = max;
uint8_t* q8 = (uint8_t*)dst;
for (int i=0; i<n; i++) {
if (fabs((double)(src[i])) < 1e-6) {
q8[i] = 0;
continue;
}
q8[i] = (double)src[i] * i2_scale > 0 ? 1 : 3;
}
// q8 -> 0, 1, 3
// | | |
// 0, 1,-1
uint8_t* i2_weight = (uint8_t*)dst;
for (int i=0; i<n; i++) {
int group_idx = i / 4;
int group_pos = i % 4;
uint8_t temp = (q8[i] << (6 - 2 * group_pos));
q8[i] = 0;
i2_weight[group_idx] |= temp;
}
float* scale_ptr = (float*)((char*)i2_weight + n / 4);
scale_ptr[0] = i2_scale;
// 32B for alignment
return nrow * row_size / 4 + 32;
size_t quantize_q2_2(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
(void)quant_weights; // not used
const size_t row_size = ggml_row_size(GGML_TYPE_Q2_2, n_per_row);
quantize_row_q2_2_reference(src, dst, (int64_t)nrow*n_per_row);
return nrow * row_size;
}
// ====================== "True" 2-bit (de)-quantization
@ -3788,83 +3770,59 @@ static inline __m128i get_scale_shuffle(int i) {
}
#endif
//====================================== I2 ===============================================
void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const uint8_t * restrict x = vx;
const int8_t * restrict y = vy;
UNUSED(bs);
assert(n % qk == 0);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(nrc);
UNUSED(bs);
const block_q2_2 * restrict x = vx;
const block_q8_0 * restrict y = vy;
#if defined(__AVX2__)
__m256i accu = _mm256_setzero_si256();
__m256 acc = _mm256_setzero_ps();
// max group_size is 128 (2^8)
// limited by 8640 to 2 (8640 % (2 * 32) == 0)
int group_num = 2;
for (int i = 0; i < nb; ++i) {
for (int i=0; i < n / (group_num * 32); i++){
__m256i laccu = _mm256_setzero_si256();
__m256i haccu = _mm256_setzero_si256();
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i].d) );
for (int j=0; j < group_num; j++) {
__m256i xq8 = _mm256_set_epi32(
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 7]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 6]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 5]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 4]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 3]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 2]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 1]],
(int)i2s_i8s[x[i * group_num * 8 + j * 8 + 0]]
(int)q22_grid[x[i].qs[7]],
(int)q22_grid[x[i].qs[6]],
(int)q22_grid[x[i].qs[5]],
(int)q22_grid[x[i].qs[4]],
(int)q22_grid[x[i].qs[3]],
(int)q22_grid[x[i].qs[2]],
(int)q22_grid[x[i].qs[1]],
(int)q22_grid[x[i].qs[0]]
);
__m256i yq8 = _mm256_loadu_si256((const __m256i*)(y + i * group_num * 32 + j * 32));
__m256i yq8 = _mm256_loadu_si256((const __m256i*)(y[i].qs));
const __m256 q = mul_sum_i8_pairs_float(xq8, yq8);
__m128i hxq8 = _mm256_castsi256_si128(xq8);
__m128i lxq8 = _mm256_extractf128_si256(xq8, 1);
__m128i hyq8 = _mm256_castsi256_si128(yq8);
__m128i lyq8 = _mm256_extractf128_si256(yq8, 1);
__m256i hxq16 = _mm256_cvtepi8_epi16(hxq8);
__m256i lxq16 = _mm256_cvtepi8_epi16(lxq8);
__m256i hyq16 = _mm256_cvtepi8_epi16(hyq8);
__m256i lyq16 = _mm256_cvtepi8_epi16(lyq8);
__m256i hzq16 = _mm256_sign_epi16(hyq16, hxq16);
__m256i lzq16 = _mm256_sign_epi16(lyq16, lxq16);
haccu = _mm256_add_epi16(haccu, hzq16);
laccu = _mm256_add_epi16(laccu, lzq16);
acc = _mm256_fmadd_ps( d, q, acc );
}
__m256i hhzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(haccu));
__m256i hlzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(haccu, 1));
__m256i llzq32 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(laccu));
__m256i lhzq32 = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(laccu, 1));
accu = _mm256_add_epi32(accu, hhzq32);
accu = _mm256_add_epi32(accu, hlzq32);
accu = _mm256_add_epi32(accu, llzq32);
accu = _mm256_add_epi32(accu, lhzq32);
}
int sumi = hsum_i32_8(accu);
*s = (float)sumi;
*s = hsum_float_8(acc);
#else
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
int sumi = 0;
for (int i = 0; i < n / 4; i++) {
const int8_t* weight = (const int8_t *)(i2s_i8s + x[i]);
sumi += (int)y[i*4+0] * weight[0];
sumi += (int)y[i*4+1] * weight[1];
sumi += (int)y[i*4+2] * weight[2];
sumi += (int)y[i*4+3] * weight[3];
for (int j = 0; j < qk / 4; j++) {
const int8_t* weight = (const int8_t *)(q22_grid + x[i].qs[j]);
sumi += (int)y[i].qs[4*j+0] * weight[0];
sumi += (int)y[i].qs[4*j+1] * weight[1];
sumi += (int)y[i].qs[4*j+2] * weight[2];
sumi += (int)y[i].qs[4*j+3] * weight[3];
}
*s = (float)sumi;
sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d));
}
*s = sumf;
#endif
}
@ -14411,6 +14369,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
}
}
} break;
case GGML_TYPE_Q2_2:
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_q2_2, data, nb);
} break;
case GGML_TYPE_Q4_0:
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb);
@ -14509,7 +14471,6 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_I2_S:
// nothing to validate
break;
default:

View file

@ -12,6 +12,7 @@ extern "C" {
#endif
// Quantization
void quantize_row_q2_2_reference(const float * GGML_RESTRICT x, block_q2_2 * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
@ -32,6 +33,7 @@ void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs
void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_2(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@ -51,7 +53,6 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_i8_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k, float* n);
// Dequantization
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@ -79,6 +80,7 @@ void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_
void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// Dot product
void ggml_vec_dot_q2_2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@ -100,7 +102,6 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_i2_i8_s (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
@ -118,12 +119,12 @@ size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q2_2(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_i2_s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
void iq2xs_init_impl(enum ggml_type type);
void iq2xs_free_impl(enum ggml_type type);

52
ggml.c
View file

@ -616,6 +616,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_F16,
.nrows = 1,
},
[GGML_TYPE_Q2_2] = {
.type_name = "q2_2",
.blck_size = QK2_2,
.type_size = sizeof(block_q2_2),
.is_quantized = true,
.from_float = quantize_row_q2_2,
.from_float_reference = (ggml_from_float_t) quantize_row_q2_2_reference,
.vec_dot = ggml_vec_dot_q2_2_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_Q4_0] = {
.type_name = "q4_0",
.blck_size = QK4_0,
@ -908,21 +919,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
.vec_dot_type = GGML_TYPE_BF16,
.nrows = 1,
},
[GGML_TYPE_I2_S] = {
.type_name = "i2_s",
.blck_size = 1,
.type_size = sizeof(int8_t),
.is_quantized = true,
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_i2_i8_s,
.vec_dot_type = GGML_TYPE_I8_S,
.nrows = 1,
},
[GGML_TYPE_I8_S] = {
.type_name = "i8_s",
.blck_size = 1,
.type_size = sizeof(int8_t),
.is_quantized = true,
}
};
@ -3071,9 +3067,6 @@ GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
if(tensor->type == GGML_TYPE_I2_S){
nbytes = nbytes / 4 + 32;
}
}
else {
nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
@ -12289,10 +12282,6 @@ static void ggml_compute_forward_mul_mat_one_chunk(
// 16 * 2, accounting for mmla kernels
float tmp[32];
// for per-tensor quant
const float * scale = (float * )((uint8_t*) (src0->data) + (ne00 * ne01 / 4));
const float * act_scales = (const float*) ((const char *) wdata + (ne11 * ne10));
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
@ -12325,13 +12314,8 @@ static void ggml_compute_forward_mul_mat_one_chunk(
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
if (src0->type == GGML_TYPE_I2_S) {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01 / 4, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
tmp[ir0 - iir0] = tmp[ir0 - iir0] / (act_scales[i11]) * (*scale);
} else {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
}
}
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
@ -12494,17 +12478,12 @@ UseGgmlGemm1:;
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
if (src0->type == GGML_TYPE_I2_S) {
float* act_scales = (float*) ((char *) wdata + (ne11 * ne10));
quantize_row_i8_s((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (char *) wdata + ((i11*nb11 + i12*nb12 + i13*nb13) / 4), ne10, act_scales + i11);
} else {
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
wdata += row_size;
}
}
}
}
}
return;
}
@ -14189,6 +14168,7 @@ static void ggml_compute_forward_clamp(
} break;
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q2_2:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@ -14215,8 +14195,6 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_F64:
case GGML_TYPE_I2_S:
case GGML_TYPE_I8_S:
case GGML_TYPE_COUNT:
{
GGML_ASSERT(false);
@ -21340,6 +21318,7 @@ size_t ggml_quantize_chunk(
size_t result = 0;
switch (type) {
case GGML_TYPE_Q2_2: result = quantize_q2_2(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
@ -21359,7 +21338,6 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_I2_S: result = quantize_i2_s (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_F16:
{
size_t elemsize = sizeof(ggml_fp16_t);
@ -21382,11 +21360,7 @@ size_t ggml_quantize_chunk(
assert(false);
}
if (type == GGML_TYPE_I2_S) {
result = nrows * row_size / 4 + 32;
} else {
GGML_ASSERT(result == nrows * row_size);
}
return result;
}

3
ggml.h
View file

@ -377,8 +377,7 @@ extern "C" {
GGML_TYPE_F64 = 28,
GGML_TYPE_IQ1_M = 29,
GGML_TYPE_BF16 = 30,
GGML_TYPE_I2_S = 31,
GGML_TYPE_I8_S = 32,
GGML_TYPE_Q2_2 = 31,
GGML_TYPE_COUNT,
};

View file

@ -202,6 +202,13 @@ class MODEL_TENSOR(IntEnum):
ATTN_KV_A_NORM = auto()
FFN_SUB_NORM = auto()
ATTN_SUB_NORM = auto()
ATTN_Q_SCALE = auto()
ATTN_K_SCALE = auto()
ATTN_V_SCALE = auto()
ATTN_OUT_SCALE = auto()
FFN_UP_SCALE = auto()
FFN_DOWN_SCALE = auto()
FFN_GATE_SCALE = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@ -293,6 +300,13 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
MODEL_TENSOR.ATTN_Q_SCALE: "blk.{bid}.attn_q_scale",
MODEL_TENSOR.ATTN_K_SCALE: "blk.{bid}.attn_k_scale",
MODEL_TENSOR.ATTN_V_SCALE: "blk.{bid}.attn_v_scale",
MODEL_TENSOR.ATTN_OUT_SCALE: "blk.{bid}.attn_output_scale",
MODEL_TENSOR.FFN_UP_SCALE: "blk.{bid}.ffn_up_scale",
MODEL_TENSOR.FFN_DOWN_SCALE: "blk.{bid}.ffn_down_scale",
MODEL_TENSOR.FFN_GATE_SCALE: "blk.{bid}.ffn_gate_scale",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -819,17 +833,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_SUB_NORM,
MODEL_TENSOR.FFN_SUB_NORM,
MODEL_TENSOR.ATTN_Q_SCALE,
MODEL_TENSOR.ATTN_K_SCALE,
MODEL_TENSOR.ATTN_V_SCALE,
MODEL_TENSOR.ATTN_OUT_SCALE,
MODEL_TENSOR.FFN_UP_SCALE,
MODEL_TENSOR.FFN_DOWN_SCALE,
MODEL_TENSOR.FFN_GATE_SCALE,
],
# TODO
}

View file

@ -498,6 +498,13 @@ enum llm_tensor {
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
LLM_TENSOR_FFN_SUB_NORM,
LLM_TENSOR_ATTN_Q_SCALE,
LLM_TENSOR_ATTN_K_SCALE,
LLM_TENSOR_ATTN_V_SCALE,
LLM_TENSOR_ATTN_OUTPUT_SCALE,
LLM_TENSOR_FFN_UP_SCALE,
LLM_TENSOR_FFN_DOWN_SCALE,
LLM_TENSOR_FFN_GATE_SCALE,
};
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
@ -1127,6 +1134,13 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_SUB_NORM, "blk.%d.ffn_sub_norm" },
{ LLM_TENSOR_ATTN_Q_SCALE, "blk.%d.attn_q_scale" },
{ LLM_TENSOR_ATTN_K_SCALE, "blk.%d.attn_q_scale" },
{ LLM_TENSOR_ATTN_V_SCALE, "blk.%d.attn_q_scale" },
{ LLM_TENSOR_ATTN_OUTPUT_SCALE, "blk.%d.attn_output_scale" },
{ LLM_TENSOR_FFN_UP_SCALE, "blk.%d.ffn_up_scale" },
{ LLM_TENSOR_FFN_DOWN_SCALE, "blk.%d.ffn_down_scale" },
{ LLM_TENSOR_FFN_GATE_SCALE, "blk.%d.ffn_gate_scale" },
},
},
{
@ -2075,6 +2089,15 @@ struct llama_layer {
// long rope factors
struct ggml_tensor * rope_long = nullptr;
struct ggml_tensor * rope_short = nullptr;
// bitnet scale
struct ggml_tensor * wq_scale;
struct ggml_tensor * wk_scale;
struct ggml_tensor * wv_scale;
struct ggml_tensor * wo_scale;
struct ggml_tensor * ffn_gate_scale;
struct ggml_tensor * ffn_up_scale;
struct ggml_tensor * ffn_down_scale;
};
struct llama_kv_cell {
@ -6460,16 +6483,23 @@ static bool llm_load_tensors(
layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
layer.wq_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_SCALE, "weight", i), {1});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.wk_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K_SCALE, "weight", i), {1});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
layer.wv_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V_SCALE, "weight", i), {1});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
layer.wo_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUTPUT_SCALE, "weight", i), {1});
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_gate_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SCALE, "weight", i), {1});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
layer.ffn_down_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SCALE, "weight", i), {1});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
layer.ffn_up_scale = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SCALE, "weight", i), {1});
}
} break;
default:
@ -11545,6 +11575,7 @@ struct llm_build_context {
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
@ -11553,6 +11584,7 @@ struct llm_build_context {
// B1.K
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
@ -11561,6 +11593,7 @@ struct llm_build_context {
// B1.V
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
@ -11659,6 +11692,7 @@ struct llm_build_context {
ggml_build_forward_expand(graph, cur_attn);
cur = ggml_mul_mat(ctx0, wo, cur_attn);
cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
cb(cur, "kqv_out", il);
}
@ -11681,10 +11715,12 @@ struct llm_build_context {
cb(cur, "ffn_norm", il);
struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_up_scale);
cb(tmp, "ffn_up", il);
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur);
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_gate_scale);
cb(cur, "ffn_gate", il);
@ -11701,6 +11737,7 @@ struct llm_build_context {
cb(cur, "ffn_sub_norm", il);
cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur);
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
cb(cur, "ffn_down", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
@ -15444,6 +15481,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
llama_ftype ftype = params->ftype;
switch (params->ftype) {
case LLAMA_FTYPE_MOSTLY_Q2_2: default_type = GGML_TYPE_Q2_2; break;
case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
@ -15452,7 +15490,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break;
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
case LLAMA_FTYPE_MOSTLY_I2_S: default_type = GGML_TYPE_I2_S; break;
// K-quants
case LLAMA_FTYPE_MOSTLY_Q2_K_S:

View file

@ -156,7 +156,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
LLAMA_FTYPE_MOSTLY_I2_S = 33,
LLAMA_FTYPE_MOSTLY_Q2_2 = 33, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};