clean code

This commit is contained in:
root 2024-06-09 20:22:03 +08:00
parent dbee0a86c1
commit 1c5a8b7fec
7 changed files with 64 additions and 216 deletions

View file

@ -659,6 +659,24 @@ static inline __m128i packNibbles( __m256i bytes ) {
}
#endif //__loongarch_asx
void quantize_row_i8_s(const float * x, void * y, int64_t n, float* act_scales) {
int8_t* dst = (int8_t*)y;
double min = 0.00001;
double max = min;
for (int i = 0; i < n; ++i) {
max = MAX(max, (double)fabs(x[i]));
}
float s = 127 / max;
act_scales[0] = s;
float temp;
for (int i = 0; i < n; ++i) {
temp = round(x[i] * s);
if (temp > 127) temp = 127;
if (temp < -128) temp = -128;
dst[i] = (int8_t)(temp);
}
}
// reference implementation for deterministic creation of model files
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
static const int qk = QK4_0;
@ -3308,7 +3326,9 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
// 2 bits per weight
size_t row_size = ggml_row_size(GGML_TYPE_I2, n_per_row);
UNUSED(quant_weights);
size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row);
int n = nrow * n_per_row;
@ -3326,7 +3346,7 @@ size_t quantize_i2_s(const float * restrict src, void * restrict dst, int64_t nr
q8[i] = 0;
continue;
}
q8[i] = src[i] * i2_scale > 0 ? 1 : 3;
q8[i] = (double)src[i] * i2_scale > 0 ? 1 : 3;
}
// q8 -> 0, 1, 3
@ -3773,14 +3793,19 @@ static inline __m128i get_scale_shuffle(int i) {
//====================================== I2 ===============================================
void ggml_vec_dot_i2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
void ggml_vec_dot_i2_i8_s(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const uint8_t * restrict x = vx;
const int8_t * restrict y = vy;
UNUSED(bs);
UNUSED(bx);
UNUSED(by);
UNUSED(nrc);
int sumi = 0;
for (int i = 0; i < n / 4; i++) {
int8_t* weight = (const int8_t *)(i2_q8 + x[i]);
const int8_t* weight = (const int8_t *)(i2_q8 + x[i]);
sumi += (int)y[i*4+0] * weight[0];
sumi += (int)y[i*4+1] * weight[1];
sumi += (int)y[i*4+2] * weight[2];
@ -14431,7 +14456,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_I16:
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_I2:
case GGML_TYPE_I2_S:
// nothing to validate
break;
default: