From 675425563c7e151c268daa6a9bed144d4991a677 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 27 Jul 2023 21:16:10 +0300 Subject: [PATCH 1/6] ggml : poc for normalizing weights for better quantization --- ggml-cuda.cu | 30 ++++++++----- ggml.c | 100 ++++++++++++++++++++++++----------------- ggml.h | 6 +-- llama.cpp | 125 ++++++++++++++++++++++++++++++++++++++++----------- 4 files changed, 180 insertions(+), 81 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d31fc79c1..371c9e84a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -74,14 +74,17 @@ typedef void (*ggml_cuda_op_t)( // QR = QK / number of values before dequantization // QI = number of 32 bit integers before dequantization +#define Q4_0DM (1.0f/8.0f) +#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f) + #define QK4_0 32 #define QR4_0 2 #define QI4_0 (QK4_0 / (4 * QR4_0)) typedef struct { - half d; // delta + int8_t d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); +static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); #define QK4_1 32 #define QR4_1 2 @@ -103,16 +106,20 @@ typedef struct { } block_q5_0; static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); +#define Q5_1DM (2.0f/31.0f) +#define Q5_1MM (2.0f ) +#define Q5_1D(x) ( (((x) & 0x0F)*Q5_1DM) / 15.0f) +#define Q5_1M(x) (-1.0f + (((x) >> 4)*Q5_1MM) / 15.0f) + #define QK5_1 32 #define QR5_1 2 #define QI5_1 (QK5_1 / (4 * QR5_1)) typedef struct { - half d; // delta - half m; // min + uint8_t dm; // 4-bit delta + 4-bit min uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_1 / 2]; // nibbles / quants } block_q5_1; -static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); +static_assert(sizeof(block_q5_1) == sizeof(uint8_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); #define QK8_0 32 #define QR8_0 1 @@ -360,7 +367,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; - const dfloat d = x[ib].d; + const dfloat d = Q4_0D(x[ib].d); const int vui = x[ib].qs[iqs]; @@ -422,8 +429,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q5_1 * x = (const block_q5_1 *) vx; - const dfloat d = x[ib].d; - const dfloat m = x[ib].m; + const dfloat d = Q5_1D(x[ib].dm); + const dfloat m = Q5_1M(x[ib].dm); uint32_t qh; memcpy(&qh, x[ib].qh, sizeof(qh)); @@ -1336,7 +1343,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1( const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]); - const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d); + const float d = Q4_0D(bq4_0->d) * __half2float(bq8_1->d); // subtract 8 from each quantized value const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808); @@ -1419,14 +1426,15 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1( #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; + // TODO: fix 
misaligned access const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]); const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2); const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2); const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]); - const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d); - const float m = bq5_1->m; + const float d = Q5_1D(bq5_1->dm) * __half2float(bq8_1->d); + const float m = Q5_1M(bq5_1->dm); const float s = bq8_1->s; int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits diff --git a/ggml.c b/ggml.c index b77f99267..b338cbf50 100644 --- a/ggml.c +++ b/ggml.c @@ -892,12 +892,16 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { #endif #endif +// we know the values are in the [-1 .. 1] range, so abs(d) cannot be more than 1/8 when using 4 bits +#define Q4_0DM (1.0f/8.0f) +#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f) + #define QK4_0 32 typedef struct { - ggml_fp16_t d; // delta + int8_t d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); +static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); #define QK4_1 32 typedef struct { @@ -915,14 +919,21 @@ typedef struct { } block_q5_0; static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); +// we know the values are in the [-1 .. 1] range, so: +// - d is unsigned 4-bit that represents maximum value of 2.0/31 when using 5 bits +// - m is unsigned 4-bit that represents offset from -1.0 which cannot be more than 2.0 +#define Q5_1DM (2.0f/31.0f) +#define Q5_1MM (2.0f ) +#define Q5_1D(x) ( (((x) & 0x0F)*Q5_1DM) / 15.0f) +#define Q5_1M(x) (-1.0f + (((x) >> 4)*Q5_1MM) / 15.0f) + #define QK5_1 32 typedef struct { - ggml_fp16_t d; // delta - ggml_fp16_t m; // min + uint8_t dm; // 4-bit delta + 4-bit min uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_1 / 2]; // nibbles / quants } block_q5_1; -static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); +static_assert(sizeof(block_q5_1) == sizeof(uint8_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); #define QK8_0 32 typedef struct { @@ -959,10 +970,13 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r } } - const float d = max / -8; - const float id = d ? 1.0f/d : 0.0f; + float d = max / -8; - y[i].d = GGML_FP32_TO_FP16(d); + y[i].d = (int8_t)(ceilf((127.0f * d) / Q4_0DM)); + + d = Q4_0D(y[i].d); + + const float id = d ? 1.0f/d : 0.0f; for (int j = 0; j < qk/2; ++j) { const float x0 = x[i*qk + 0 + j]*id; @@ -1088,11 +1102,17 @@ static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * r if (v > max) max = v; } - const float d = (max - min) / ((1 << 5) - 1); - const float id = d ? 1.0f/d : 0.0f; + y[i].dm = (uint8_t)(floorf((15.0f * (min + 1.0f)) / Q5_1MM)) << 4; - y[i].d = GGML_FP32_TO_FP16(d); - y[i].m = GGML_FP32_TO_FP16(min); + min = Q5_1M(y[i].dm); + + float d = (max - min) / ((1 << 5) - 1); + + y[i].dm |= (uint8_t)(ceilf((15.0f * d) / Q5_1DM)); + + d = Q5_1D(y[i].dm); + + const float id = d ? 
1.0f/d : 0.0f; uint32_t qh = 0; @@ -1530,7 +1550,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict const int nb = k / qk; for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); + const float d = Q4_0D(x[i].d); for (int j = 0; j < qk/2; ++j) { const int x0 = (x[i].qs[j] & 0x0F) - 8; @@ -1597,8 +1617,8 @@ static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict const int nb = k / qk; for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - const float m = GGML_FP16_TO_FP32(x[i].m); + const float d = Q5_1D(x[i].dm); + const float m = Q5_1M(x[i].dm); uint32_t qh; memcpy(&qh, x[i].qh, sizeof(qh)); @@ -2407,8 +2427,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), Q4_0D(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), Q4_0D(x1->d)*GGML_FP16_TO_FP32(y1->d)); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); @@ -2425,8 +2445,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), Q4_0D(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), Q4_0D(x1->d)*GGML_FP16_TO_FP32(y1->d)); #endif } @@ -2438,7 +2458,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; ++i) { /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + const __m256 d = _mm256_set1_ps( Q4_0D(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); __m256i bx = bytes_from_nibbles_32(x[i].qs); @@ -2462,7 +2482,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; ++i) { // Compute combined scale for the block - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + const __m256 d = _mm256_set1_ps( Q4_0D(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); const __m128i lowMask = _mm_set1_epi8(0xF); const __m128i off = _mm_set1_epi8(8); @@ -2504,7 +2524,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); + const __m128 d_0_1 = _mm_set1_ps( Q4_0D(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); @@ -2522,7 +2542,7 @@ static void 
ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); + const __m128 d_2_3 = _mm_set1_ps( Q4_0D(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); @@ -2555,7 +2575,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); + const __m128 d_0_1 = _mm_set1_ps( Q4_0D(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); @@ -2573,7 +2593,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); + const __m128 d_2_3 = _mm_set1_ps( Q4_0D(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); @@ -2621,7 +2641,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); } - sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); + sumf += sumi*Q4_0D(x[i].d)*GGML_FP16_TO_FP32(y[i].d); } *s = sumf; @@ -3026,8 +3046,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * const uint8x16_t m4b = vdupq_n_u8(0x0F); - summs0 += GGML_FP16_TO_FP32(x0->m) * y0->s; - summs1 += GGML_FP16_TO_FP32(x1->m) * y1->s; + summs0 += Q5_1M(x0->dm) * y0->s; + summs1 += Q5_1M(x1->dm) * y1->s; // extract the 5th bit via lookup table ((b) << 4) memcpy(&qh0, x0->qh, sizeof(qh0)); @@ -3072,10 +3092,10 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * #if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d); + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), Q5_1D(x0->dm)*y0->d); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d); + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), Q5_1D(x1->dm)*y1->d); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); @@ -3092,8 +3112,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), Q5_1D(x0->dm)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), Q5_1D(x1->dm)*y1->d); #endif } @@ -3111,7 +3131,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * const block_q5_1 * restrict 
x0 = &x[i]; const block_q8_1 * restrict y0 = &y[i]; - summs += GGML_FP16_TO_FP32(x0->m) * y0->s; + summs += Q5_1M(x0->dm) * y0->s; const v128_t m4b = wasm_i8x16_splat(0x0F); @@ -3158,7 +3178,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * wasm_i32x4_dot_i16x8(v0lfh, v1lh)), wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d))); + wasm_f32x4_splat(Q5_1D(x0->dm) * y0->d))); } *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + @@ -3171,9 +3191,9 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); + const __m256 dx = _mm256_set1_ps(Q5_1D(x[i].dm)); - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + summs += Q5_1M(x[i].dm) * y[i].s; __m256i bx = bytes_from_nibbles_32(x[i].qs); __m256i bxhi = bytes_from_bits_32(x[i].qh); @@ -3198,9 +3218,9 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)); + const __m256 dx = _mm256_set1_ps(Q5_1D(x[i].dm)); - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + summs += Q5_1M(x[i].dm) * y[i].s; __m256i bx = bytes_from_nibbles_32(x[i].qs); const __m256i bxhi = bytes_from_bits_32(x[i].qh); @@ -3243,7 +3263,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); } - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + sumf += (Q5_1D(x[i].dm)*y[i].d)*sumi + Q5_1M(x[i].dm)*y[i].s; } *s = sumf; @@ -5470,7 +5490,7 @@ struct ggml_tensor * ggml_sum_rows( } int64_t ne[4] = {1,1,1,1}; - for (int i=1; in_dims; ++i) { + for (int i = 1; i < a->n_dims; ++i) { ne[i] = a->ne[i]; } diff --git a/ggml.h b/ggml.h index 9919cce7c..854f79437 100644 --- a/ggml.h +++ b/ggml.h @@ -281,9 +281,9 @@ extern "C" { GGML_TYPE_Q5_K = 13, GGML_TYPE_Q6_K = 14, GGML_TYPE_Q8_K = 15, - GGML_TYPE_I8, - GGML_TYPE_I16, - GGML_TYPE_I32, + GGML_TYPE_I8 = 16, + GGML_TYPE_I16 = 17, + GGML_TYPE_I32 = 18, GGML_TYPE_COUNT, }; diff --git a/llama.cpp b/llama.cpp index 9a8ecdcf6..682a07a17 100644 --- a/llama.cpp +++ b/llama.cpp @@ -119,7 +119,7 @@ static const std::map & MEM_REQ_SCRATCH1() { static std::map k_sizes = { { MODEL_3B, 128ull * MB }, - { MODEL_7B, 160ull * MB }, + { MODEL_7B, 200ull * MB }, { MODEL_13B, 192ull * MB }, { MODEL_30B, 256ull * MB }, { MODEL_65B, 384ull * MB }, // guess @@ -229,6 +229,11 @@ struct llama_layer { struct ggml_tensor * wv; struct ggml_tensor * wo; + struct ggml_tensor * wq_a; + struct ggml_tensor * wk_a; + struct ggml_tensor * wv_a; + struct ggml_tensor * wo_a; + // normalization struct ggml_tensor * ffn_norm; @@ -236,6 +241,10 @@ struct llama_layer { struct ggml_tensor * w1; struct ggml_tensor * w2; struct ggml_tensor * w3; + + struct ggml_tensor * w1_a; + struct ggml_tensor * w2_a; + struct ggml_tensor * w3_a; }; struct llama_kv_cache { @@ -1208,17 +1217,29 @@ static void llama_model_load_internal( layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd_gqa}, backend_split); layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split); + layer.wq_a = ml->get_tensor(layers_i + ".attention.wq.weight.a", {n_embd}, backend); + layer.wk_a = ml->get_tensor(layers_i + ".attention.wk.weight.a", 
{n_embd_gqa}, backend); + layer.wv_a = ml->get_tensor(layers_i + ".attention.wv.weight.a", {n_embd_gqa}, backend); + layer.wo_a = ml->get_tensor(layers_i + ".attention.wo.weight.a", {n_embd}, backend); + layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split); layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split); layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split); + layer.w1_a = ml->get_tensor(layers_i + ".feed_forward.w1.weight.a", { n_ff}, backend); + layer.w2_a = ml->get_tensor(layers_i + ".feed_forward.w2.weight.a", {n_embd}, backend); + layer.w3_a = ml->get_tensor(layers_i + ".feed_forward.w3.weight.a", { n_ff}, backend); + if (backend == GGML_BACKEND_GPU) { vram_weights += - ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + - ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + - ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3) + + ggml_nbytes(layer.wq_a) + ggml_nbytes(layer.wk_a) + ggml_nbytes(layer.wv_a) + + ggml_nbytes(layer.wo_a) + ggml_nbytes(layer.w1_a) + ggml_nbytes(layer.w2_a) + + ggml_nbytes(layer.w3_a); } } } @@ -1360,6 +1381,34 @@ static bool llama_model_load( } } +// computes: Z = (X @ Y) * a +// a is vector with size equal to rows of X. each element is the scaling factor used to normalize X's rows +// the ggml_mul() is broadcasted row-wise to restore the normalization +struct ggml_tensor * ggml_mul_mat_ex( + struct ggml_context * ctx0, + struct ggml_tensor * t, + struct ggml_tensor * a, + //struct ggml_tensor * b, + struct ggml_tensor * cur, + offload_func_t offload_func) { + cur = ggml_mul_mat(ctx0, t, cur); + offload_func(cur); + + cur = ggml_mul(ctx0, cur, a); + offload_func(cur); + + return cur; + + //struct ggml_tensor * tmp = ggml_mul_mat(ctx0, t, cur); + //tmp = ggml_mul(ctx0, tmp, a); + //cur = ggml_add(ctx0, tmp, + // ggml_mul(ctx0, + // ggml_repeat(ctx0, ggml_sum_rows(ctx0, cur), tmp), + // b) + // ); + //return cur; +} + // evaluate the transformer // // - lctx: llama context @@ -1502,12 +1551,10 @@ static bool llama_eval_internal( // self-attention { // compute Q and K and RoPE them - struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); - offload_func_kq(tmpk); + struct ggml_tensor * tmpk = ggml_mul_mat_ex(ctx0, model.layers[il].wk, model.layers[il].wk_a, cur, offload_func_kq); ggml_set_name(tmpk, "tmpk"); - struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); - offload_func_kq(tmpq); + struct ggml_tensor * tmpq = ggml_mul_mat_ex(ctx0, model.layers[il].wq, model.layers[il].wq_a, cur, offload_func_kq); ggml_set_name(tmpq, "tmpq"); struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); @@ -1522,8 +1569,7 @@ static bool llama_eval_internal( { // compute the transposed [N, n_embd] V matrix - struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur); - offload_func_v(tmpv); + struct ggml_tensor * tmpv = ggml_mul_mat_ex(ctx0, model.layers[il].wv, model.layers[il].wv_a, cur, 
offload_func_v); ggml_set_name(tmpv, "tmpv"); struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N)); @@ -1620,10 +1666,7 @@ static bool llama_eval_internal( ggml_set_name(cur, "KQV_merged_contiguous"); // projection (no bias) - cur = ggml_mul_mat(ctx0, - model.layers[il].wo, - cur); - offload_func(cur); + cur = ggml_mul_mat_ex(ctx0, model.layers[il].wo, model.layers[il].wo_a, cur, offload_func); ggml_set_name(cur, "result_wo"); } @@ -1647,16 +1690,10 @@ static bool llama_eval_internal( ggml_set_name(cur, "ffn_norm"); } - struct ggml_tensor * tmp = ggml_mul_mat(ctx0, - model.layers[il].w3, - cur); - offload_func(tmp); + struct ggml_tensor * tmp = ggml_mul_mat_ex(ctx0, model.layers[il].w3, model.layers[il].w3_a, cur, offload_func); ggml_set_name(tmp, "result_w3"); - cur = ggml_mul_mat(ctx0, - model.layers[il].w1, - cur); - offload_func(cur); + cur = ggml_mul_mat_ex(ctx0, model.layers[il].w1, model.layers[il].w1_a, cur, offload_func); ggml_set_name(cur, "result_w1"); // SILU activation @@ -1668,10 +1705,7 @@ static bool llama_eval_internal( offload_func(cur); ggml_set_name(cur, "silu_x_result_w3"); - cur = ggml_mul_mat(ctx0, - model.layers[il].w2, - cur); - offload_func(cur); + cur = ggml_mul_mat_ex(ctx0, model.layers[il].w2, model.layers[il].w2_a, cur, offload_func); ggml_set_name(cur, "result_w2"); } @@ -2936,7 +2970,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } else { new_type = quantized_type; #ifdef GGML_USE_K_QUANTS - if (tensor.name == "output.weight") { + if (tensor.name == "output.weight" || tensor.name == "tok_embeddings.weight") { int nx = tensor.ne.at(0); int ny = tensor.ne.at(1); if (nx % QK_K == 0 && ny % QK_K == 0) { @@ -2997,6 +3031,43 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s f32_data = (float *) f32_conv_buf.addr; } + // TODO: this is temporary since we only implemented Q4_0 and Q5_1 as POC + if (new_type == GGML_TYPE_Q4_0 || new_type == GGML_TYPE_Q5_1) { + //printf("\n dims: %d x %d\n", tensor.ne.at(0), tensor.ne.at(1)); + + const uint32_t nr = tensor.ne.at(1); + + std::vector va(nr); + std::vector vb(nr); + + // normalize to -1..1 per rows + for (uint32_t r = 0; r < nr; ++r) { + const uint32_t n = tensor.ne.at(0); + float * p = f32_data + r * n; + + float amax = 0.0f; + for (size_t i = 0; i < n; ++i) { + amax = std::max(amax, std::abs(p[i])); + } + + for (size_t i = 0; i < n; ++i) { + p[i] = p[i] / amax; + } + + va[r] = amax; + } + + { + llama_load_tensor ta; + ta.name = tensor.name + ".a"; + ta.type = GGML_TYPE_F32; + ta.ne = std::vector(1, nr); + ta.size = nr * sizeof(float); + ta.data = (uint8_t *) va.data(); + file_saver.write_tensor(ta, GGML_TYPE_F32, ta.data, ta.size); + } + } + printf("quantizing to %s .. 
", ggml_type_name(new_type)); fflush(stdout); From a4d1eb72c6a4b8531c483f631cb6259c3cd3b7af Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 28 Jul 2023 14:37:52 +0300 Subject: [PATCH 2/6] ggml : add q4_1 normalized quants --- ggml-cuda.cu | 20 ++++++++++++-------- ggml.c | 48 +++++++++++++++++++++++++++++------------------- llama.cpp | 4 ++-- 3 files changed, 43 insertions(+), 29 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 371c9e84a..0638db693 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -86,15 +86,19 @@ typedef struct { } block_q4_0; static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); +#define Q4_1DM (2.0f/15.0f) +#define Q4_1MM (2.0f ) +#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f) +#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f) + #define QK4_1 32 #define QR4_1 2 #define QI4_1 (QK4_1 / (4 * QR4_1)) typedef struct { - half d; // delta - half m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants + uint16_t dm; // 8-bit delta + 8-bit min (can be adjusted easily) + uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; -static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); +static_assert(sizeof(block_q4_1) == sizeof(uint16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); #define QK5_0 32 #define QR5_0 2 @@ -386,8 +390,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ const block_q4_1 * x = (const block_q4_1 *) vx; - const dfloat d = x[ib].d; - const dfloat m = x[ib].m; + const dfloat d = Q4_1D(x[ib].dm); + const dfloat m = Q4_1M(x[ib].dm); const int vui = x[ib].qs[iqs]; @@ -1368,8 +1372,8 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1( const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]); - const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d); - const float m = bq4_1->m; + const float d = Q4_1D(bq4_1->dm) * __half2float(bq8_1->d); + const float m = Q4_1M(bq4_1->dm); const float s = bq8_1->s; const int vi0 = (vi >> 0) & 0x0F0F0F0F; diff --git a/ggml.c b/ggml.c index b338cbf50..d478dd557 100644 --- a/ggml.c +++ b/ggml.c @@ -903,13 +903,17 @@ typedef struct { } block_q4_0; static_assert(sizeof(block_q4_0) == sizeof(int8_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); +#define Q4_1DM (2.0f/15.0f) +#define Q4_1MM (2.0f ) +#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f) +#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f) + #define QK4_1 32 typedef struct { - ggml_fp16_t d; // delta - ggml_fp16_t m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants + uint16_t dm; // 8-bit delta + 8-bit min (can be adjusted easily) + uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; -static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); +static_assert(sizeof(block_q4_1) == sizeof(uint16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); #define QK5_0 32 typedef struct { @@ -929,7 +933,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5 #define QK5_1 32 typedef struct { - uint8_t dm; // 4-bit delta + 4-bit min + uint8_t dm; // 4-bit delta + 4-bit min (can be adjusted easily) uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_1 / 2]; // nibbles / quants } block_q5_1; @@ -1013,11 +1017,17 @@ static void 
quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r if (v > max) max = v; } - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; + y[i].dm = (uint16_t)(floorf((255.0f * (min + 1.0f)) / Q4_1MM)) << 8; - y[i].d = GGML_FP32_TO_FP16(d); - y[i].m = GGML_FP32_TO_FP16(min); + min = Q4_1M(y[i].dm); + + float d = (max - min) / ((1 << 4) - 1); + + y[i].dm |= (uint16_t)(ceilf((255.0f * d) / Q4_1DM)); + + d = Q4_1D(y[i].dm); + + const float id = d ? 1.0f/d : 0.0f; for (int j = 0; j < qk/2; ++j) { const float x0 = (x[i*qk + 0 + j] - min)*id; @@ -1570,8 +1580,8 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict const int nb = k / qk; for (int i = 0; i < nb; i++) { - const float d = GGML_FP16_TO_FP32(x[i].d); - const float m = GGML_FP16_TO_FP32(x[i].m); + const float d = Q4_1D(x[i].dm); + const float m = Q4_1M(x[i].dm); for (int j = 0; j < qk/2; ++j) { const int x0 = (x[i].qs[j] & 0x0F); @@ -2671,7 +2681,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const block_q8_1 * restrict y0 = &y[i + 0]; const block_q8_1 * restrict y1 = &y[i + 1]; - summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s; + summs += Q4_1M(x0->dm) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s; const uint8x16_t m4b = vdupq_n_u8(0x0F); @@ -2695,8 +2705,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), Q4_1D(x0->dm)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), Q4_1D(x1->dm)*y1->d); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); @@ -2713,8 +2723,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), Q4_1D(x0->dm)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), Q4_1D(x1->dm)*y1->d); #endif } @@ -2727,10 +2737,10 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; ++i) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); + const float d0 = Q4_1D(x[i].dm); const float d1 = y[i].d; - summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; + summs += Q4_1M(x[i].dm) * y[i].s; const __m256 d0v = _mm256_set1_ps( d0 ); const __m256 d1v = _mm256_set1_ps( d1 ); @@ -2767,7 +2777,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); } - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + sumf += (Q4_1D(x[i].dm)*y[i].d)*sumi + Q4_1M(x[i].dm)*y[i].s; } *s = sumf; diff --git a/llama.cpp b/llama.cpp index 682a07a17..663586f77 100644 --- a/llama.cpp +++ 
b/llama.cpp @@ -3031,8 +3031,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s f32_data = (float *) f32_conv_buf.addr; } - // TODO: this is temporary since we only implemented Q4_0 and Q5_1 as POC - if (new_type == GGML_TYPE_Q4_0 || new_type == GGML_TYPE_Q5_1) { + // TODO: this is temporary since we only implemented Q4_0, Q4_1 and Q5_1 as PoC + if (new_type == GGML_TYPE_Q4_0 || new_type == GGML_TYPE_Q4_1 || new_type == GGML_TYPE_Q5_1) { //printf("\n dims: %d x %d\n", tensor.ne.at(0), tensor.ne.at(1)); const uint32_t nr = tensor.ne.at(1); From e5d23f2e7e2dd952f0964b26da89fe619d2b025c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 28 Jul 2023 16:31:59 +0300 Subject: [PATCH 3/6] ggml : fix ARM build + speed-up ggml_mul --- ggml.c | 21 ++++++++++----------- ggml.h | 6 +++--- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/ggml.c b/ggml.c index d478dd557..63f2fc3c0 100644 --- a/ggml.c +++ b/ggml.c @@ -2681,7 +2681,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const block_q8_1 * restrict y0 = &y[i + 0]; const block_q8_1 * restrict y1 = &y[i + 1]; - summs += Q4_1M(x0->dm) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s; + summs += Q4_1M(x0->dm) * y0->s + Q4_1M(x1->dm) * y1->s; const uint8x16_t m4b = vdupq_n_u8(0x0F); @@ -8898,6 +8898,13 @@ static void ggml_compute_forward_mul_f32( const int64_t nr = ggml_nrows(src0); + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + GGML_TENSOR_BINARY_OP_LOCALS; GGML_ASSERT( nb0 == sizeof(float)); @@ -8905,7 +8912,7 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT(ne00 == ne10); if (nb10 == sizeof(float)) { - for (int64_t ir = ith; ir < nr; ir += nth) { + for (int64_t ir = ir0; ir < ir1; ++ir) { // src0 and dst are same shape => same indices const int64_t i03 = ir/(ne02*ne01); const int64_t i02 = (ir - i03*ne02*ne01)/ne01; @@ -8919,19 +8926,11 @@ static void ggml_compute_forward_mul_f32( float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); -#ifdef GGML_USE_ACCELERATE - UNUSED(ggml_vec_mul_f32); - - vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00); -#else ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr); -#endif - // } - // } } } else { // src1 is not contiguous - for (int64_t ir = ith; ir < nr; ir += nth) { + for (int64_t ir = ir0; ir < ir1; ++ir) { // src0 and dst are same shape => same indices // src1 is broadcastable across src0 and dst in i1, i2, i3 const int64_t i03 = ir/(ne02*ne01); diff --git a/ggml.h b/ggml.h index 854f79437..9919cce7c 100644 --- a/ggml.h +++ b/ggml.h @@ -281,9 +281,9 @@ extern "C" { GGML_TYPE_Q5_K = 13, GGML_TYPE_Q6_K = 14, GGML_TYPE_Q8_K = 15, - GGML_TYPE_I8 = 16, - GGML_TYPE_I16 = 17, - GGML_TYPE_I32 = 18, + GGML_TYPE_I8, + GGML_TYPE_I16, + GGML_TYPE_I32, GGML_TYPE_COUNT, }; From 72af25998c65e2fc0affb57347323d58a9781a12 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 28 Jul 2023 17:12:27 +0300 Subject: [PATCH 4/6] Fix misaligned memory access in Q4_1 kernel --- ggml-cuda.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0638db693..cc874d6be 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1368,7 +1368,9 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1( #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer 
intrinsics const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; - const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]); + int vi; + memcpy(&vi, &bq4_1->qs[sizeof(int) * (iqs + 0)], sizeof(vi)); + //const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]); const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]); From dead8f4b5b801356c9d9f7891aa6fc8068597cff Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 28 Jul 2023 17:27:01 +0300 Subject: [PATCH 5/6] Fix misaligned memory access in Q4_1 kernel --- ggml-cuda.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index cc874d6be..13e21bcd0 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1433,7 +1433,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1( const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; // TODO: fix misaligned access - const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]); + int qs; + memcpy(&qs, &bq5_1->qs[sizeof(int) * (iqs + 0)], sizeof(qs)); + //const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]); const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2); const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2); const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); From b4e70822f6282a6b3c9ae53a282d30c9d5ccf70f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 30 Aug 2023 18:32:43 +0300 Subject: [PATCH 6/6] metal : add poc for normalized Q4_0 and Q4_1 --- ggml-metal.m | 7 +++-- ggml-metal.metal | 67 +++++++++++++++++++++++++++++------------------- 2 files changed, 45 insertions(+), 29 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index e929c4b07..1aaff6a93 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -697,6 +697,9 @@ void ggml_metal_graph_compute( } break; case GGML_OP_MUL: { + GGML_ASSERT(ne00 % 4 == 0); + const int64_t nb = ne00/4; + if (ggml_nelements(src1) == ne10) { // src1 is a row [encoder setComputePipelineState:ctx->pipeline_mul_row]; @@ -706,9 +709,9 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:3]; - const int64_t n = ggml_nelements(dst); + const int64_t n = ggml_nelements(dst)/4; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index 82e1a0c7a..bfb32eccd 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4,17 +4,22 @@ using namespace metal; #define MAX(x, y) ((x) > (y) ? 
(x) : (y)) +#define Q4_0DM (1.0f/8.0f) +#define Q4_0D(x) (((x)*Q4_0DM) / 127.0f) #define QK4_0 32 #define QR4_0 2 typedef struct { - half d; // delta + int8_t d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; +#define Q4_1DM (2.0f/15.0f) +#define Q4_1MM (2.0f ) +#define Q4_1D(x) ( (((x) & 0xFF)*Q4_1DM) / 255.0f) +#define Q4_1M(x) (-1.0f + (((x) >> 8)*Q4_1MM) / 255.0f) #define QK4_1 32 typedef struct { - half d; // delta - half m; // min + uint16_t dm; uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; @@ -44,9 +49,9 @@ kernel void kernel_add_row( } kernel void kernel_mul( - device const float * src0, - device const float * src1, - device float * dst, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * src1[tpig]; } @@ -54,12 +59,12 @@ kernel void kernel_mul( // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_mul_row( - device const float * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant int64_t & nb, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig % ne00]; + dst[tpig] = src0[tpig] * src1[tpig % nb]; } kernel void kernel_scale( @@ -314,14 +319,18 @@ kernel void kernel_rms_norm( // we assume that the yl's have been multiplied with the appropriate scale factor // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; + float d = Q4_0D(qb_curr->d); float2 acc = 0.f; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + device const uint8_t * qs = ((device const uint8_t *)qb_curr->qs + il); + uint16_t qs16; for (int i = 0; i < 8; i+=2) { - acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); + qs16 = qs[i+1]; + qs16 <<= 8; + qs16 |= qs[i]; + acc[0] += yl[i + 0] * (qs16 & 0x000F) + + yl[i + 1] * (qs16 & 0x0F00); + acc[1] += yl[i + 8] * (qs16 & 0x00F0) + + yl[i + 9] * (qs16 & 0xF000); } return d * (sumy * -8.f + acc[0] + acc[1]); } @@ -331,9 +340,9 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre // we assume that the yl's have been multiplied with the appropriate scale factor // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { - float d = qb_curr->d; - float m = qb_curr->m; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float d = Q4_1D(qb_curr->dm); + float m = Q4_1M(qb_curr->dm); + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); float2 acc = 0.f; for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) @@ -1686,23 +1695,27 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) template void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const half d = il ? (xb->d / 16.h) : xb->d; + device const uint8_t * qs = ((device const uint8_t *)xb->qs); + const half d = il ? (Q4_0D(xb->d) / 16.h) : Q4_0D(xb->d); const half m = il ? 
( -8.h * 16.h) : -8.h; const ushort mask0 = il ? 0x00F0 : 0x000F; const ushort mask1 = il ? 0xF000 : 0x0F00; + uint16_t qs16; for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d; - reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d; + qs16 = qs[2*i+1]; + qs16 <<= 8; + qs16 |= qs[2*i]; + reg[i/2][2*(i%2)] = (((qs16 & mask0) ) + m) * d; + reg[i/2][2*(i%2)+1] = (((qs16 & mask1) >> 8) + m) * d; } } template void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { - device const uint16_t * qs = ((device const uint16_t *)xb + 2); - const half d = il ? (xb->d / 16.h) : xb->d; - const half m = xb->m; + device const uint16_t * qs = ((device const uint16_t *)xb + 1); + const half d = il ? (Q4_1D(xb->dm) / 16.h) : Q4_1D(xb->dm); + const half m = Q4_1M(xb->dm); const ushort mask0 = il ? 0x00F0 : 0x000F; const ushort mask1 = il ? 0xF000 : 0x0F00;
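
For reference, below is a minimal standalone C sketch of the scheme the patches above implement, reduced to a single 32-value block per row. The helper names (normalize_row, quantize_block_norm, dequantize_block_norm) are illustrative only and are not ggml API. Each weight row is divided by its absolute maximum so its values land in [-1, 1]; the per-row maxima are saved as an extra F32 tensor named "<tensor>.a"; and that bounded range is what lets the Q4_0 block scale shrink from an fp16 to a single signed byte via the Q4_0DM/Q4_0D macros (Q4_1 and Q5_1 pack their delta and min into a small integer the same way). At inference, ggml_mul_mat_ex multiplies the mat-mul result by the ".a" vector, which has one element per row of the normalized matrix, so the broadcasted ggml_mul restores the original scale.

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    #define QK 32                    // values per block (one block per row in this sketch)
    #define Q4_0DM (1.0f/8.0f)       // |d| <= 1/8 once the row is normalized to [-1, 1]
    #define Q4_0D(x) (((x)*Q4_0DM) / 127.0f)

    typedef struct {
        int8_t  d;                   // 1-byte delta (replaces the fp16 delta)
        uint8_t qs[QK/2];            // 4-bit quants, two per byte
    } block_q4_0n;

    // divide one row by its absolute maximum and return that maximum (the ".a" factor)
    static float normalize_row(float * row, int n) {
        float amax = 0.0f;
        for (int i = 0; i < n; ++i) amax = fmaxf(amax, fabsf(row[i]));
        if (amax > 0.0f) for (int i = 0; i < n; ++i) row[i] /= amax;
        return amax;
    }

    // quantize one block whose values are already in [-1, 1]
    static void quantize_block_norm(const float * x, block_q4_0n * y) {
        float amax = 0.0f, max = 0.0f;
        for (int j = 0; j < QK; ++j) {
            if (fabsf(x[j]) > amax) { amax = fabsf(x[j]); max = x[j]; }
        }
        float d = max / -8.0f;
        y->d = (int8_t) ceilf((127.0f*d) / Q4_0DM);    // scale now fits in one signed byte
        d = Q4_0D(y->d);                               // re-read the rounded scale before quantizing
        const float id = d ? 1.0f/d : 0.0f;
        for (int j = 0; j < QK/2; ++j) {
            float q0 = fminf(15.0f, fmaxf(0.0f, x[j       ]*id + 8.5f));
            float q1 = fminf(15.0f, fmaxf(0.0f, x[j + QK/2]*id + 8.5f));
            y->qs[j] = (uint8_t) q0 | ((uint8_t) q1 << 4);
        }
    }

    // dequantize one block and undo the row normalization with the stored ".a" factor
    static void dequantize_block_norm(const block_q4_0n * y, float a, float * x) {
        const float d = Q4_0D(y->d);
        for (int j = 0; j < QK/2; ++j) {
            x[j       ] = a*d*((y->qs[j] & 0x0F) - 8);
            x[j + QK/2] = a*d*((y->qs[j] >>   4) - 8);
        }
    }

    int main(void) {
        float row[QK], out[QK], orig[QK];
        for (int i = 0; i < QK; ++i) orig[i] = row[i] = 3.0f*sinf(0.37f*i);  // arbitrary data

        const float a = normalize_row(row, QK);    // llama.cpp writes this out as "<name>.a"

        block_q4_0n blk;
        quantize_block_norm(row, &blk);
        dequantize_block_norm(&blk, a, out);

        double err = 0.0;
        for (int i = 0; i < QK; ++i) err += fabs(out[i] - orig[i]);
        printf("row scale a = %.4f, mean abs error = %.5f\n", a, err/QK);
        return 0;
    }

Building with cc sketch.c -lm and checking the printed round-trip error is a quick sanity check of the idea: the 1-byte block scale plus the per-row ".a" factor reconstruct the weights much like the old fp16 scale did, while the block sizes shrink as in the static_asserts above (1 byte per 32 weights for Q4_0, 2 for Q4_1, 3 for Q5_1), at the cost of one extra F32 per row.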