diff --git a/examples/alpaca.sh b/examples/alpaca.sh index 8d6261730..aef207f36 100755 --- a/examples/alpaca.sh +++ b/examples/alpaca.sh @@ -7,4 +7,13 @@ cd `dirname $0` cd .. -./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt --ctx_size 2048 -n -1 -ins -b 256 --top_k 10000 --temp 0.2 --repeat_penalty 1 -t 7 +./main -m ./models/ggml-alpaca-7b-q4.bin \ + --color \ + -f ./prompts/alpaca.txt \ + --ctx_size 2048 \ + -n -1 \ + -ins -b 256 \ + --top_k 10000 \ + --temp 0.2 \ + --repeat_penalty 1.1 \ + -t 7 diff --git a/ggml.c b/ggml.c index 0ee72bc84..8067fb703 100644 --- a/ggml.c +++ b/ggml.c @@ -452,6 +452,24 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi) return bytes; } +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + #if __AVX2__ || __AVX512F__ // Unpack 32 4-bit fields into 32 bytes // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval @@ -472,6 +490,24 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) return bytes; } +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + const __m256i summed_pairs = _mm256_madd_epi16(ones, x); + return _mm256_cvtepi32_ps(summed_pairs); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_float(dot); +} + static inline __m128i packNibbles( __m256i bytes ) { // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh @@ -622,10 +658,11 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong #define QK8_0 32 typedef struct { float d; // delta - float s; // d * sum(qs[i]) + float s0; // d * sum(qs[i]) low + float s1; // d * sum(qs[i]) high int8_t qs[QK8_0]; // quants } block_q8_0; -static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +static_assert(sizeof(block_q8_0) == 3*sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); // reference implementation for deterministic creation of model files @@ -1265,39 +1302,25 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r y[i].d = d; - int sum = 0; - for (int l = 0; l < QK8_0; ++l) { - const float v = x[i*QK8_0 + l]*id; - y[i].qs[l] = roundf(v); - sum += y[i].qs[l]; + int sum0 = 0; + int sum1 = 0; + + for (int l = 0; l < QK8_0/2; ++l) { + const float v0 = x[i*QK8_0 + l]*id; + const float v1 = x[i*QK8_0 + QK8_0/2 + l]*id; + + y[i].qs[ l] = roundf(v0); + y[i].qs[QK8_0/2 + l] = roundf(v1); + + sum0 += y[i].qs[ l]; + sum1 += y[i].qs[QK8_0/2 + l]; } - y[i].s = d * sum; + + y[i].s0 = d * sum0; + y[i].s1 = d * sum1; } } -#ifdef __AVX2__ -// There is no better way of doing this? -// I guess not, AVX is not very good at horizontal sums. -// The commented solution for a hotrizontal sum was suggested by @pubby as being slightly -// faster than the solution below. As I don't have an AVX2 system handt right now to test, -// keeping the original. -// TODO: Please try and if it does make a differece, uncomment and remove the implementation below. -//static inline float horizontal_sum(__m256i a) { -// __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a))); -// __m256i sum = _mm256_add_epi32(a, b); -// __m256i hi = _mm256_unpackhi_epi64(sum, sum); -// sum = _mm256_add_epi32(sum, hi); -// return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4); -//} -static inline float horizontal_sum(__m256i a) { - __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1)); - __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); - __m128i sum64 = _mm_add_epi32(hi64, sum128); - __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} -#endif - static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -1324,9 +1347,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int y[i].d = d; - int32x4_t accv = vdupq_n_s32(0); + int32x4_t accv0 = vdupq_n_s32(0); + int32x4_t accv1 = vdupq_n_s32(0); - for (int l = 0; l < 8; l++) { + // low half + for (int l = 0; l < 4; l++) { const float32x4_t v = vmulq_n_f32(srcv[l], id); const int32x4_t vi = vcvtnq_s32_f32(v); @@ -1335,12 +1360,30 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2); y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3); - accv = vaddq_s32(accv, vi); + accv0 = vaddq_s32(accv0, vi); } - int32_t sum = vaddvq_s32(accv); - y[i].s = d * sum; + + // high half + for (int l = 4; l < 8; l++) { + const float32x4_t v = vmulq_n_f32(srcv[l], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3); + + accv1 = vaddq_s32(accv1, vi); + } + + const int32_t sum0 = vaddvq_s32(accv0); + const int32_t sum1 = vaddvq_s32(accv1); + + y[i].s0 = d * sum0; + y[i].s1 = d * sum1; } #elif defined(__AVX2__) || defined(__AVX__) + // TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! for (int i = 0; i < nb; i++) { // Load elements into 4 AVX vectors __m256 v0 = _mm256_loadu_ps( x ); @@ -1386,9 +1429,10 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int __m256i i3 = _mm256_cvtps_epi32( v3 ); #if defined(__AVX2__) - // Compute the sum of the quants and set y[i].s - y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + //y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + y[i].s0 = d * hsum_i32_8(_mm256_add_epi32(i0, i1)); + y[i].s1 = d * hsum_i32_8(_mm256_add_epi32(i2, i3)); // Convert int32 to int16 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 @@ -1415,6 +1459,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int __m128i ni6 = _mm256_castsi256_si128( i3 ); __m128i ni7 = _mm256_extractf128_si256( i3, 1); + // Compute the sum of the quants and set y[i].s + const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); + const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); + y[i].s = d * hsum_i32_8(_mm256_set_m128i(s1, s0)); + // Convert int32 to int16 ni0 = _mm_packs_epi32( ni0, ni1 ); ni2 = _mm_packs_epi32( ni2, ni3 ); @@ -1432,14 +1481,6 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int // scalar quantize_row_q8_0_reference(x, y, k); #endif -#if defined __AVX__ - // TODO: vectorize this - for (int i=0; id * y0->s + x1->d * y1->s; + sum8 += x0->d * (y0->s0 + y0->s1) + x1->d * (y1->s0 + y1->s1); const uint8x16_t m4b = vdupq_n_u8(0xf); @@ -2443,7 +2482,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * #endif } - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8; + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8; #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2461,32 +2500,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(bx, bx); - - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(by, bx); - - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - - const __m256i ones = _mm256_set1_epi16(1); - __m256i xy_q = _mm256_madd_epi16(ones, dot); - - /* Convert to vectore of 8 int32_t to 8 floats */ - __m256 q = _mm256_cvtepi32_ps( xy_q ); + const __m256 q = mul_sum_i8_pairs_float(bx, by); /* Multiply q with scale and accumulate */ acc = _mm256_fmadd_ps( d, q, acc ); } - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps( acc, 1 ); - res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); - res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); - res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - - sumf = _mm_cvtss_f32( res ); + *s = hsum_float_8(acc); #elif defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2525,15 +2545,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); } - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps( acc, 1 ); - res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); - res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); - res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - - sumf = _mm_cvtss_f32( res ); + *s = hsum_float_8(acc); #else // scalar + float sumf = 0.0; for (int i = 0; i < nb; i++) { const float d0 = x[i].d; const float d1 = y[i].d; @@ -2555,9 +2570,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * } sumf += d0*d1*sumi; } -#endif - *s = sumf; +#endif } static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -2569,8 +2583,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const block_q4_1 * restrict x = vx; const block_q8_0 * restrict y = vy; - float sumf = 0.0; - // TODO: add AVX / WASM SIMD / etc #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); @@ -2584,7 +2596,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; - summs += x0->m * y0->s + x1->m * y1->s; + summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1); const uint8x16_t m4b = vdupq_n_u8(0xf); @@ -2597,22 +2609,22 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + // interleave + const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h); + const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h); + const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h); + const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h); + // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - // interleave - const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h); - const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h); - const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h); - const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h); - #if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs); + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); @@ -2637,7 +2649,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * #endif } - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2648,9 +2660,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * for (int i = 0; i < nb; ++i) { const float * d0 = &x[i].d; const float * d1 = &y[i].d; - //const float * m0 = &x[i].m; - summs += x[i].m * y[i].s; + summs += x[i].m * (y[i].s0 + y[i].s1); const __m256 d0v = _mm256_broadcast_ss( d0 ); const __m256 d1v = _mm256_broadcast_ss( d1 ); @@ -2662,33 +2673,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const __m256i bx = bytes_from_nibbles_32(x[i].qs); const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8( bx, bx ); - - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8( by, bx ); - - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16( ax, sy ); - const __m256i ones = _mm256_set1_epi16( 1 ); - const __m256i xy_q = _mm256_madd_epi16( ones, dot ); - - // Convert to vector of 8 int32_t to 8 floats - const __m256 xy = _mm256_cvtepi32_ps( xy_q ); + const __m256 xy = mul_sum_i8_pairs_float(bx, by); // Accumulate d0*d1*x*y acc = _mm256_fmadd_ps( d0d1, xy, acc ); } - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps( acc, 1 ); - res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); - res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); - res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - - sumf = _mm_cvtss_f32( res ) + summs; + *s = hsum_float_8(acc) + summs; #else // scalar + float sumf = 0.0; for (int i = 0; i < nb; i++) { const float d0 = x[i].d; const float m0 = x[i].m; @@ -2710,9 +2704,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * sumf += f0*f2 + f1*f3; } } -#endif - *s = sumf; +#endif } static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -2725,8 +2718,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * const block_q4_2 * restrict x = vx; const block_q8_0 * restrict y = vy; - float sumf = 0.0; - #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -2804,7 +2795,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * #endif } - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2826,32 +2817,16 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(bx, bx); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(by, bx); - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - - const __m256i ones = _mm256_set1_epi16(1); - __m256i xy_q = _mm256_madd_epi16(ones, dot); - - /* Convert to vectore of 8 int32_t to 8 floats */ - __m256 q = _mm256_cvtepi32_ps(xy_q); + const __m256 q = mul_sum_i8_pairs_float(bx, by); /* Multiply q with scale and accumulate */ acc = _mm256_fmadd_ps(d, q, acc); } - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps(acc, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(acc)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - - sumf = _mm_cvtss_f32(res); + *s = hsum_float_8(acc); #else // scalar + float sumf = 0.0; for (int i = 0; i < nb; i++) { const uint8_t * restrict x0 = x[2*i + 0].qs; const uint8_t * restrict x1 = x[2*i + 1].qs; @@ -2886,9 +2861,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * sumf += (d0 * y[i].d) * sumi_0; sumf += (d1 * y[i].d) * sumi_1; } -#endif - *s = sumf; +#endif } static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -2901,96 +2875,91 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * const block_q4_3 * restrict x = vx; const block_q8_0 * restrict y = vy; - float sumf = 0.0; - #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); - for (int i = 0; i < nb; i += 2) { + float summs0 = 0.0f; + float summs1 = 0.0f; + + for (int i = 0; i < nb; ++i) { const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0]; const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1]; - const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0]; - const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1]; const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; - const uint8x16_t m4b = vdupq_n_u8(0xf); - - const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); - const float x0_1d = GGML_FP16_TO_FP32(x0_1->d); - const float x1_0d = GGML_FP16_TO_FP32(x1_0->d); - const float x1_1d = GGML_FP16_TO_FP32(x1_1->d); - - const float x0_0m = GGML_FP16_TO_FP32(x0_0->m); - const float x0_1m = GGML_FP16_TO_FP32(x0_1->m); - const float x1_0m = GGML_FP16_TO_FP32(x1_0->m); - const float x1_1m = GGML_FP16_TO_FP32(x1_1->m); + summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0; + summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1; const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); - const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs)); // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf))); const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); // interleave const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h); const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h); - const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h); - const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h); // load y const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l))); - const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h))); - - const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l))); - const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h))); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d); + const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); + const float x0_1d = GGML_FP16_TO_FP32(x0_1->d); #if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h)); - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d); #endif } - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; i++) { + const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); + const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); + const __m256 dx = _mm256_set_m128(d1, d0); + + const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m)); + const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m)); + const __m256 mx = _mm256_set_m128(m1, m0); + + const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); + const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); + const __m256i bx = _mm256_set_m128i(bx1, bx0); + + const __m256 dy = _mm256_broadcast_ss(&y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by); + const __m256 syf = sum_i16_pairs_float(syi); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf)); + acc = _mm256_fmadd_ps(sxy, dy, acc); + } + + *s = hsum_float_8(acc); #else // scalar + float sumf = 0.0; for (int i = 0; i < nb; i++) { const uint8_t * restrict x0 = x[2*i + 0].qs; const uint8_t * restrict x1 = x[2*i + 1].qs; @@ -3001,9 +2970,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d); const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m); - int sy_0 = 0; - int sy_1 = 0; - int sxy_0 = 0; int sxy_1 = 0; @@ -3023,19 +2989,14 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * const int y0_1 = y0[2*(j + QK8_0/4) + 0]; const int y1_1 = y0[2*(j + QK8_0/4) + 1]; - sy_0 += y0_0 + y1_0; - sy_1 += y0_1 + y1_1; - sxy_0 += x0_0*y0_0 + x1_0*y1_0; sxy_1 += x0_1*y0_1 + x1_1*y1_1; } - sumf += (d0*sxy_0 + m0*sy_0)*y[i].d; - sumf += (d1*sxy_1 + m1*sy_1)*y[i].d; + sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1; } -#endif - *s = sumf; +#endif } diff --git a/llama.cpp b/llama.cpp index bb160804a..d5eeffc0b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 @@ -1794,7 +1795,7 @@ struct llama_context * llama_init_from_file( if (params.logits_all) { ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); } else { - ctx->logits.reserve(hparams.n_ctx); + ctx->logits.reserve(hparams.n_vocab); } if (params.embedding){ @@ -2258,4 +2259,123 @@ const char * llama_print_system_info(void) { // For internal test use std::vector>& llama_internal_get_tensor_map(struct llama_context * ctx) { return ctx->model.tensors_by_name; -} \ No newline at end of file +} + +// Returns the size of the state +size_t llama_get_state_size(struct llama_context * ctx) { + const size_t s_bool = sizeof(int32_t); + // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. + // for reference, std::mt19937(1337) serializes to 6701 bytes. + const size_t s_rng_size = sizeof(size_t); + const size_t s_rng = 64*1024; + const size_t s_logits_capacity = sizeof(size_t); + const size_t s_logits_size = sizeof(size_t); + const size_t s_logits = ctx->logits.capacity() * sizeof(float); + const size_t s_embedding_size = sizeof(size_t); + const size_t s_embedding = ctx->embedding.size() * sizeof(float); + const size_t s_kv_size = sizeof(size_t); + const size_t s_kv_ntok = sizeof(int); + const size_t s_kv = llama_get_kv_cache_size(ctx); + const size_t s_total = ( + + s_rng_size + + s_rng + + s_logits_capacity + + s_logits_size + + s_logits + + s_embedding_size + + s_embedding + + s_kv_size + + s_kv_ntok + + s_kv + ); + return s_total; +} + +// Copies the state to the specified destination address +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { + std::stringstream rng_ss; + rng_ss << ctx->rng; + const size_t rng_size = rng_ss.str().size(); + char rng_buf[64*1024]; + memset(&rng_buf[0], 0, 64*1024); + memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); + const size_t logits_capacity = ctx->logits.capacity(); + const size_t logits_size = ctx->logits.size(); + const size_t embedding_size = ctx->embedding.size(); + const size_t kv_size = llama_get_kv_cache_size(ctx); + const int kv_ntok = llama_get_kv_cache_token_count(ctx); + + uint8_t * out = dest; + memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t); + memcpy(out, &rng_buf[0], 64*1024); out += 64*1024; + memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t); + memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t); + if (logits_size) { + memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); + } + out += logits_capacity * sizeof(float); + memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t); + if (embedding_size) { + memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float); + } + memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t); + memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int); + if (kv_size) { + memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size; + } + const size_t written = out - dest; + const size_t expected = llama_get_state_size(ctx); + LLAMA_ASSERT(written == expected); + return written; +} + +// Sets the state reading from the specified source address +size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { + size_t rng_size; + char rng_buf[64*1024]; + std::stringstream rng_ss; + + const uint8_t * in = src; + memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t); + memcpy(&rng_buf[0], in, 64*1024); in += 64*1024; + rng_ss.str(std::string(&rng_buf[0], rng_size)); + rng_ss >> ctx->rng; + LLAMA_ASSERT(rng_ss.fail() == false); + + size_t logits_capacity; + size_t logits_size; + size_t embedding_size; + size_t kv_size; + int kv_ntok; + + memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t); + memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t); + LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity); + if (logits_size) { + ctx->logits.resize(logits_size); + memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); + } + in += logits_capacity * sizeof(float); + memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t); + LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); + if (embedding_size) { + memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); + in += embedding_size * sizeof(float); + } + memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t); + memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int); + if (kv_size) { + LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size); + void * k_data = ctx->model.kv_self.k->data; // remember data pointers + void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy + memcpy(ctx->model.kv_self.buf.addr, in, kv_size); + ctx->model.kv_self.k->data = k_data; // restore correct data pointers + ctx->model.kv_self.v->data = v_data; + in += kv_size; + } + ctx->model.kv_self.n = kv_ntok; + const size_t nread = in - src; + const size_t expected = llama_get_state_size(ctx); + LLAMA_ASSERT(nread == expected); + return nread; +} diff --git a/llama.h b/llama.h index e95ff73b8..f68a0cb40 100644 --- a/llama.h +++ b/llama.h @@ -129,6 +129,18 @@ extern "C" { size_t n_size, int n_token_count); + // Returns the size in bytes of the state (rng, logits, embedding and kv_cache) + LLAMA_API size_t llama_get_state_size(struct llama_context * ctx); + + // Copies the state to the specified destination address. + // Destination needs to have allocated enough memory. + // Returns the number of bytes copied + LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest); + + // Set the state reading from the specified address + // Returns the number of bytes read + LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src); + // Run the llama inference to obtain the logits and probabilities for the next token. // tokens + n_tokens is the provided batch of new tokens to process // n_past is the number of tokens to use from previous eval calls