split into functions

This commit is contained in:
Eve 2024-11-02 18:08:13 -04:00
parent 7de0bdc2db
commit b8d592fe2c

View file

@ -149,6 +149,28 @@ static inline __m128i packNibbles( __m256i bytes )
#endif #endif
} }
#elif defined(__AVX__) #elif defined(__AVX__)
// Pack the low nibbles of 16-bit lanes from two vectors into one vector of bytes.
// Each input holds eight 16-bit lanes of the form 0000_abcd_0000_efgh; the result
// is sixteen bytes abcd_efgh — bytes1's lanes in the low half, bytes2's in the high half.
// NOTE(review): assumes the upper nibble of every byte in the inputs is zero as the
// comment pattern implies — values outside that form would saturate in packus; confirm at call sites.
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
{
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
const __m128i lowByte = _mm_set1_epi16( 0xFF );
// bytes1: split each 16-bit lane into its high byte (masked out) and low byte
__m128i high = _mm_andnot_si128( lowByte, bytes1 );
__m128i low = _mm_and_si128( lowByte, bytes1 );
// shift the high byte's nibble down 4 bits so it lands just above the low nibble
high = _mm_srli_epi16( high, 4 );
bytes1 = _mm_or_si128( low, high );
// same transformation for the second input, reusing the temporaries
high = _mm_andnot_si128( lowByte, bytes2 );
low = _mm_and_si128( lowByte, bytes2 );
high = _mm_srli_epi16( high, 4 );
bytes2 = _mm_or_si128( low, high );
// narrow the sixteen 16-bit lanes to sixteen bytes (unsigned saturation)
return _mm_packus_epi16( bytes1, bytes2);
}
// Multiply two vectors of sixteen signed 8-bit ints and horizontally add adjacent
// pairs, yielding eight signed 16-bit sums (an i8 dot product over byte pairs).
// _mm_maddubs_epi16 requires its first operand to be unsigned, so the sign of x is
// first folded into y: |x| * sign(x)*y == x * y for each lane.
static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
// ax = |x| per lane (sign of a value applied to itself is its absolute value)
const __m128i ax = _mm_sign_epi8(x, x);
// sy = y with each lane negated/zeroed according to the sign of the matching x lane
const __m128i sy = _mm_sign_epi8(y, x);
// unsigned(ax) * signed(sy), adjacent pairs summed into 16-bit lanes with saturation
return _mm_maddubs_epi16(ax, sy);
}
// spread 32 bits to 32 bytes { 0x00, 0xFF } // spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) { static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32; uint32_t x32;
@ -216,26 +238,23 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
return sum_i16_pairs_float(doth, dotl); return sum_i16_pairs_float(doth, dotl);
} }
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) // larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors
{ static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1,
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) {
const __m128i lowByte = _mm_set1_epi16( 0xFF ); const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0);
__m128i high = _mm_andnot_si128( lowByte, bytes1 ); const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1);
__m128i low = _mm_and_si128( lowByte, bytes1 ); const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0);
high = _mm_srli_epi16( high, 4 ); const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1);
bytes1 = _mm_or_si128( low, high ); __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);
high = _mm_andnot_si128( lowByte, bytes2 ); __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);
low = _mm_and_si128( lowByte, bytes2 ); return sum_i16_pairs_float(p_2, p_1);
high = _mm_srli_epi16( high, 4 );
bytes2 = _mm_or_si128( low, high );
return _mm_packus_epi16( bytes1, bytes2);
} }
static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { // fp16 delta calculation intended for mul_sum_i8_quad_float
const __m128i ax = _mm_sign_epi8(x, x); static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) {
const __m128i sy = _mm_sign_epi8(y, x); // GGML_FP16_TO_FP32 is faster than Intel F16C
return _mm_maddubs_epi16(ax, sy); return _mm256_set_m128(_mm_set1_ps(GGML_FP16_TO_FP32(x1) * GGML_FP16_TO_FP32(y1)),
_mm_set1_ps(GGML_FP16_TO_FP32(x0) * GGML_FP16_TO_FP32(y0)));
} }
#endif #endif
#elif defined(__SSSE3__) #elif defined(__SSSE3__)
@ -4227,18 +4246,9 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8)); const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8)); const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d);
const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
__m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1);
p_1 = _mm_madd_epi16(p_1, _mm_set1_epi16(1));
__m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1);
p_2 = _mm_madd_epi16(p_2, _mm_set1_epi16(1));
const __m256 deltas = _mm256_set_m128(_mm_set1_ps(GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d)),
_mm_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)));
accum = _mm256_add_ps(_mm256_mul_ps(deltas, _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1))), accum);
} }
sumf = hsum_float_8(accum); sumf = hsum_float_8(accum);