Q8: use int8_t, AVX/AVX2 optimizations

This commit is contained in:
Stephan Walter 2023-04-13 18:55:55 +02:00 committed by Georgi Gerganov
parent 19e7a6575d
commit 2c4f9b658d
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

214
ggml.c
View file

@ -592,7 +592,7 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 bloc
typedef struct {
float d; // delta
uint8_t qs[QK]; // nibbles / quants
int8_t qs[QK]; // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK, "wrong q8_0 block size/padding");
@ -1069,9 +1069,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
for (int l = 0; l < QK; ++l) {
const float v = x[i*QK + l]*id;
const uint8_t vi = (int8_t)roundf(v) + 128;
y[i].qs[l] = vi;
y[i].qs[l] = roundf(v);
}
}
}
@ -1104,8 +1102,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
for (int l = 0; l < 8; l++) {
const float32x4_t v = vmulq_n_f32(srcv[l], id);
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(128.5f));
const int32x4_t vi = vcvtq_s32_f32(vf);
//TODO: rounding
const int32x4_t vi = vcvtq_s32_f32(v);
y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
@ -1113,6 +1111,90 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
}
}
#elif defined(__AVX2__) || defined(__AVX__)
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
__m256 v1 = _mm256_loadu_ps( x + 8 );
__m256 v2 = _mm256_loadu_ps( x + 16 );
__m256 v3 = _mm256_loadu_ps( x + 24 );
x += 32;
// Compute max(abs(e)) for the block
const __m256 signBit = _mm256_set1_ps( -0.0f );
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 );
// Quantize these floats
const float d = maxScalar / 127.f;
y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
// Apply the multiplier
v0 = _mm256_mul_ps( v0, mul );
v1 = _mm256_mul_ps( v1, mul );
v2 = _mm256_mul_ps( v2, mul );
v3 = _mm256_mul_ps( v3, mul );
// Round to nearest integer
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
// Convert floats to integers
__m256i i0 = _mm256_cvtps_epi32( v0 );
__m256i i1 = _mm256_cvtps_epi32( v1 );
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );
#if defined(__AVX2__)
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
// Convert int16 to int8
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
// We got our precious signed bytes, but the order is now wrong
// These AVX2 pack instructions process 16-byte pieces independently
// The following instruction is fixing the order
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
#else
// Since we don't have in AVX some necessary functions,
// we split the registers in half and call AVX2 analogs from SSE
__m128i ni0 = _mm256_castsi256_si128( i0 );
__m128i ni1 = _mm256_extractf128_si256( i0, 1);
__m128i ni2 = _mm256_castsi256_si128( i1 );
__m128i ni3 = _mm256_extractf128_si256( i1, 1);
__m128i ni4 = _mm256_castsi256_si128( i2 );
__m128i ni5 = _mm256_extractf128_si256( i2, 1);
__m128i ni6 = _mm256_castsi256_si128( i3 );
__m128i ni7 = _mm256_extractf128_si256( i3, 1);
// Convert int32 to int16
ni0 = _mm_packs_epi32( ni0, ni1 );
ni2 = _mm_packs_epi32( ni2, ni3 );
ni4 = _mm_packs_epi32( ni4, ni5 );
ni6 = _mm_packs_epi32( ni6, ni7 );
// Convert int16 to int8
ni0 = _mm_packs_epi16( ni0, ni2 );
ni4 = _mm_packs_epi16( ni4, ni6 );
_mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
_mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
#endif
}
#else
// scalar
quantize_row_q8_0_reference(x, y, k);
@ -2517,7 +2599,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
const uint8x16_t m4b = vdupq_n_u8(0xf);
const int8x16_t s8b = vdupq_n_s8(0x8);
const uint8x16_t u128b = vdupq_n_u8(128);
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@ -2535,21 +2616,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
// load y
const uint8x16_t v1_0l = vld1q_u8(y0->qs);
const uint8x16_t v1_0h = vld1q_u8(y0->qs + 16);
const uint8x16_t v1_1l = vld1q_u8(y1->qs);
const uint8x16_t v1_1h = vld1q_u8(y1->qs + 16);
const int8x16_t v1_0l = vld1q_s8(y0->qs);
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
// interleave
const uint8x16_t v1_0lz = vuzp1q_u8(v1_0l, v1_0h);
const uint8x16_t v1_0hz = vuzp2q_u8(v1_0l, v1_0h);
const uint8x16_t v1_1lz = vuzp1q_u8(v1_1l, v1_1h);
const uint8x16_t v1_1hz = vuzp2q_u8(v1_1l, v1_1h);
const int8x16_t v1_0ls = vreinterpretq_s8_u8(vsubq_u8(v1_0lz, u128b));
const int8x16_t v1_0hs = vreinterpretq_s8_u8(vsubq_u8(v1_0hz, u128b));
const int8x16_t v1_1ls = vreinterpretq_s8_u8(vsubq_u8(v1_1lz, u128b));
const int8x16_t v1_1hs = vreinterpretq_s8_u8(vsubq_u8(v1_1hz, u128b));
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
@ -2587,6 +2663,94 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
}
sumf = sum0 + sum1;
#elif defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; ++i) {
/* Compute combined scale for the block */
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
__m256i bx = bytesFromNibbles(x[i].qs);
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 );
bx = _mm256_sub_epi8( bx, off );
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
// Get absolute values of x vectors
const __m256i ax = _mm256_sign_epi8(bx, bx);
// Sign the values of the y vectors
const __m256i sy = _mm256_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
const __m256i ones = _mm256_set1_epi16(1);
__m256i xy_q = _mm256_madd_epi16(ones, dot);
/* Convert to vectore of 8 int32_t to 8 floats */
__m256 q = _mm256_cvtepi32_ps( xy_q );
/* Multiply q with scale and accumulate */
acc = _mm256_fmadd_ps( d, q, acc );
}
// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res );
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Main loop
for (int i = 0; i < nb; ++i) {
// Compute combined scale for the block
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
__m128i i32[2];
for (int j = 0; j < 2; ++j) {
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m128i off = _mm_set1_epi8( 8 );
bx = _mm_sub_epi8( bx, off );
// Get absolute values of x vectors
const __m128i ax = _mm_sign_epi8(bx, bx);
// Sign the values of the y vectors
const __m128i sy = _mm_sign_epi8(by, bx);
// Perform multiplication and create 16-bit values
const __m128i dot = _mm_maddubs_epi16(ax, sy);
const __m128i ones = _mm_set1_epi16(1);
i32[j] = _mm_madd_epi16(ones, dot);
}
// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
// Apply the scale, and accumulate
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
}
// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps( acc, 1 );
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res );
#else
// scalar
for (int i = 0; i < nb; i++) {
@ -2594,7 +2758,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
const float d1 = y[i].d;
const uint8_t * restrict p0 = x[i].qs;
const uint8_t * restrict p1 = y[i].qs;
const int8_t * restrict p1 = y[i].qs;
int sumi = 0;
for (int j = 0; j < QK/2; j++) {
@ -2603,10 +2767,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
const int i0 = (int8_t) (v0 & 0xf) - 8;
const int i1 = (int8_t) (v0 >> 4) - 8;
const int i2 = (int) p1[2*j + 0] - 128;
const int i3 = (int) p1[2*j + 1] - 128;
/*printf("dot product: i0=%4d i1=%4d i2=%4d i3=%4d\n", i0, i1, i2, i3);*/
const int i2 = p1[2*j + 0];
const int i3 = p1[2*j + 1];
sumi += i0*i2 + i1*i3;
}
@ -10171,7 +10333,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
} else
#endif
{
cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
}
} else {
GGML_ASSERT(false);
}