diff --git a/ggml.c b/ggml.c index 8b70806aa..12add19d8 100644 --- a/ggml.c +++ b/ggml.c @@ -607,10 +607,11 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { assert(k % QK == 0); const int nb = k / QK; + const size_t bs = 2*sizeof(float) + QK/2; - float * restrict pm = (float *) (y); - float * restrict pd = (float *) (pm + nb); - uint8_t * restrict pb = (uint8_t *) (pd + nb); + uint8_t * restrict pd = ((uint8_t *)y + 0*bs); + uint8_t * restrict pm = ((uint8_t *)y + 0*bs + sizeof(float)); + uint8_t * restrict pb = ((uint8_t *)y + 0*bs + 2*sizeof(float)); uint8_t pp[QK/2]; @@ -627,8 +628,10 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { const float d = (max - min) / ((1 << 4) - 1); const float id = d ? 1.0f/d : 0.0f; - pm[i] = min; - pd[i] = d; + *(float *)pm = min; + *(float *)pd = d; + pm += bs; + pd += bs; for (int l = 0; l < QK; l += 2) { const float v0 = (x[i*QK + l + 0] - min)*id; @@ -643,7 +646,8 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { pp[l/2] = vi0 | (vi1 << 4); } - memcpy(pb + i*QK/2, pp, sizeof(pp)); + memcpy(pb, pp, sizeof(pp)); + pb += bs; } } @@ -687,16 +691,17 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) { assert(k % QK == 0); const int nb = k / QK; + const size_t bs = 2*sizeof(float) + QK/2; - const float * restrict pm = (const float *) (x); - const float * restrict pd = (const float *) (pm + nb); - const uint8_t * restrict pb = (const uint8_t *) (pd + nb); + const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs); + const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float)); + const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float)); for (int i = 0; i < nb; i++) { - const float m = pm[i]; - const float d = pd[i]; + const float d = *(const float *) (pd + i*bs); + const float m = *(const float *) (pm + i*bs); - const uint8_t * restrict pp = pb + i*QK/2; + const uint8_t * restrict pp = pb + i*bs; for (int l = 0; l < QK; l += 2) { const uint8_t vi = pp[l/2]; @@ -1584,14 +1589,16 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { const int nb = n / QK; - const float * restrict pm0 = (const float *) x; - const float * restrict pm1 = (const float *) y; + const size_t bs = 2*sizeof(float) + QK/2; - const float * restrict pd0 = (const float *) (pm0 + nb); - const float * restrict pd1 = (const float *) (pm1 + nb); + const uint8_t * restrict pd0 = ((const uint8_t *)x + 0*bs); + const uint8_t * restrict pd1 = ((const uint8_t *)y + 0*bs); - const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb); - const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb); + const uint8_t * restrict pm0 = ((const uint8_t *)x + 0*bs + sizeof(float)); + const uint8_t * restrict pm1 = ((const uint8_t *)y + 0*bs + sizeof(float)); + + const uint8_t * restrict pb0 = ((const uint8_t *)x + 0*bs + 2*sizeof(float)); + const uint8_t * restrict pb1 = ((const uint8_t *)y + 0*bs + 2*sizeof(float)); float sumf = 0.0; @@ -1604,14 +1611,14 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void // Main loop for (int i = 0; i < nb; ++i) { - const float * m0 = (const float *) (pm0 + i); - const float * m1 = (const float *) (pm1 + i); + const float * m0 = (const float *) (pm0 + i*bs); + const float * m1 = (const float *) (pm1 + i*bs); - const float * d0 = (const float *) (pd0 + i); - const float * d1 = (const float *) (pd1 + i); + const float * d0 = (const float *) (pd0 + i*bs); + const float * d1 = (const float *) (pd1 + i*bs); - const uint8_t * restrict p0 = pb0 + i*QK/2; - const uint8_t * restrict p1 = pb1 + i*QK/2; + const uint8_t * restrict p0 = pb0 + i*bs; + const uint8_t * restrict p1 = pb1 + i*bs; const __m256 d0v = _mm256_broadcast_ss( d0 ); const __m256 d1v = _mm256_broadcast_ss( d1 ); @@ -1677,14 +1684,14 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void #else // scalar for (int i = 0; i < nb; i++) { - const float m0 = pm0[i]; - const float m1 = pm1[i]; + const float * m0 = (const float *) (pm0 + i*bs); + const float * m1 = (const float *) (pm1 + i*bs); - const float d0 = pd0[i]; - const float d1 = pd1[i]; + const float * d0 = (const float *) (pd0 + i*bs); + const float * d1 = (const float *) (pd1 + i*bs); - const uint8_t * restrict p0 = pb0 + i*QK/2; - const uint8_t * restrict p1 = pb1 + i*QK/2; + const uint8_t * restrict p0 = pb0 + i*bs; + const uint8_t * restrict p1 = pb1 + i*bs; for (int j = 0; j < QK/2; j++) { const uint8_t v0 = p0[j]; diff --git a/utils.cpp b/utils.cpp index aa3ad1053..26e313d5f 100644 --- a/utils.cpp +++ b/utils.cpp @@ -489,7 +489,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t * hist) { const int nb = k / qk; - const size_t row_size = nb*(2*sizeof(float) + sizeof(uint8_t)*qk/2); + const size_t bs = (2*sizeof(float) + sizeof(uint8_t)*qk/2); + const size_t row_size = nb*bs; assert(k % qk == 0); @@ -498,10 +499,10 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t char * pdst = (char *) dst; - for (int j = 0; j < n; j += k) { - float * pm = (float *) (pdst + (j/k)*row_size); - float * pd = (float *) (pm + nb); - uint8_t * pb = (uint8_t *) (pd + nb); + for (int j = 0; j < n; j += k) { + uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs); + uint8_t * pm = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float)); + uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float)); //printf("n = %d, k = %d, nb = %d, row_size = %d, j = %d, pm = %p, pd = %p, pb = %p\n", n, k, nb, row_size, j, pm, pd, pb); @@ -519,8 +520,10 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t const float d = (max - min) / ((1 << 4) - 1); const float id = d ? 1.0f/d : 0.0f; - pm[i] = min; - pd[i] = d; + *(float *) pd = d; + *(float *) pm = min; + pd += bs; + pm += bs; for (int l = 0; l < qk; l += 2) { const float v0 = (src[j + i*qk + l + 0] - min)*id; @@ -538,7 +541,8 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t pp[l/2] = vi0 | (vi1 << 4); } - memcpy(pb + i*qk/2, pp, pp_size); + memcpy(pb, pp, pp_size); + pb += bs; } } }