Q2 AVX2: do two blocks at a time, by @slaren
This commit is contained in:
parent
6fc51a8c05
commit
c29ab90e06
1 changed files with 50 additions and 26 deletions
76
ggml.c
76
ggml.c
|
@ -488,6 +488,34 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
|
||||||
}
|
}
|
||||||
|
|
||||||
#if __AVX2__ || __AVX512F__
|
#if __AVX2__ || __AVX512F__
|
||||||
|
// Unpack 32 2-bit fields into 32 bytes
|
||||||
|
// The output vector contains 32 bytes, each one in [ 0 .. 3 ] interval
|
||||||
|
static inline __m256i bytes_from_crumbs(uint32_t packed_hi, uint32_t packed_lo) {
|
||||||
|
__m128i bx_hi = _mm_set1_epi32(packed_hi);
|
||||||
|
__m128i bx_lo = _mm_set1_epi32(packed_lo);
|
||||||
|
__m256i bx = _mm256_set_m128i(bx_hi, bx_lo);
|
||||||
|
|
||||||
|
// shift counts to get all bit pairs in lowest position of each byte
|
||||||
|
const __m256i shift256 = _mm256_set_epi32(6, 4, 2, 0,
|
||||||
|
6, 4, 2, 0);
|
||||||
|
bx = _mm256_srlv_epi32(bx, shift256);
|
||||||
|
|
||||||
|
const __m256i shufmask = _mm256_set_epi8(15,11, 7, 3,
|
||||||
|
14,10, 6, 2,
|
||||||
|
13, 9, 5, 1,
|
||||||
|
12, 8, 4, 0,
|
||||||
|
15,11, 7, 3,
|
||||||
|
14,10, 6, 2,
|
||||||
|
13, 9, 5, 1,
|
||||||
|
12, 8, 4, 0);
|
||||||
|
bx = _mm256_shuffle_epi8(bx, shufmask);
|
||||||
|
|
||||||
|
const __m256i mask = _mm256_set1_epi8(3);
|
||||||
|
bx = _mm256_and_si256(mask, bx);
|
||||||
|
|
||||||
|
return bx;
|
||||||
|
}
|
||||||
|
|
||||||
// Unpack 32 4-bit fields into 32 bytes
|
// Unpack 32 4-bit fields into 32 bytes
|
||||||
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
||||||
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
|
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
|
||||||
|
@ -2500,6 +2528,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
||||||
static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||||
assert(n % QK2_0 == 0);
|
assert(n % QK2_0 == 0);
|
||||||
const int nb = n / QK2_0;
|
const int nb = n / QK2_0;
|
||||||
|
assert(nb % 2 == 0);
|
||||||
|
|
||||||
const block_q2_0 * restrict x = vx;
|
const block_q2_0 * restrict x = vx;
|
||||||
const block_q8_0 * restrict y = vy;
|
const block_q8_0 * restrict y = vy;
|
||||||
|
@ -2508,49 +2537,44 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
|
||||||
|
|
||||||
#if defined(__AVX2__)
|
#if defined(__AVX2__)
|
||||||
// Initialize accumulator with zeros
|
// Initialize accumulator with zeros
|
||||||
__m128 acc = _mm_setzero_ps();
|
__m256 acc = _mm256_setzero_ps();
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; i += 2) {
|
||||||
|
__m256i bx = bytes_from_crumbs(x[i+1].qs, x[i].qs);
|
||||||
|
|
||||||
for (int i = 0; i < nb; i++) {
|
|
||||||
// Compute combined scale for the block
|
// Compute combined scale for the block
|
||||||
const __m128 scale = _mm_set1_ps(GGML_FP16_TO_FP32(x[i].d) * y[i/2].d);
|
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+0].d) * y[i/2].d);
|
||||||
|
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+1].d) * y[i/2].d);
|
||||||
|
const __m256 scale = _mm256_set_m128(scale_hi, scale_lo);
|
||||||
|
|
||||||
__m128i bx = _mm_set1_epi32(x[i].qs);
|
const __m256i off = _mm256_set1_epi8(2);
|
||||||
|
bx = _mm256_sub_epi8(bx, off);
|
||||||
|
|
||||||
// shift counts to get all bit pairs in lowest position of each byte
|
// Load y vector
|
||||||
const __m128i shift128 = _mm_set_epi32(6, 4, 2, 0);
|
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i/2].qs);
|
||||||
bx = _mm_srlv_epi32(bx, shift128);
|
|
||||||
|
|
||||||
const __m128i shufmask = _mm_set_epi8(15,11,7,3,14,10,6,2,13,9,5,1,12,8,4,0);
|
|
||||||
bx = _mm_shuffle_epi8(bx, shufmask);
|
|
||||||
|
|
||||||
const __m128i mask = _mm_set1_epi8(3);
|
|
||||||
bx = _mm_and_si128(mask, bx);
|
|
||||||
|
|
||||||
const __m128i off = _mm_set1_epi8(2);
|
|
||||||
bx = _mm_sub_epi8(bx, off);
|
|
||||||
|
|
||||||
const __m128i by = _mm_loadu_si128((const __m128i *)(y[i/2].qs + (i%2)*QK2_0));
|
|
||||||
|
|
||||||
// Get absolute values of x vectors
|
// Get absolute values of x vectors
|
||||||
const __m128i ax = _mm_sign_epi8(bx, bx);
|
const __m256i ax = _mm256_sign_epi8(bx, bx);
|
||||||
// Sign the values of the y vectors
|
// Sign the values of the y vectors
|
||||||
const __m128i sy = _mm_sign_epi8(by, bx);
|
const __m256i sy = _mm256_sign_epi8(by, bx);
|
||||||
// Perform multiplication and create 16-bit values
|
// Perform multiplication and create 16-bit values
|
||||||
const __m128i dot = _mm_maddubs_epi16(ax, sy);
|
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
||||||
|
|
||||||
// Convert int16_t to int32_t by adding pairwise
|
// Convert int16_t to int32_t by adding pairwise
|
||||||
const __m128i ones = _mm_set1_epi16(1);
|
const __m256i ones = _mm256_set1_epi16(1);
|
||||||
__m128i i32 = _mm_madd_epi16(dot, ones);
|
__m256i i32 = _mm256_madd_epi16(ones, dot);
|
||||||
|
|
||||||
// Convert int32_t to float
|
// Convert int32_t to float
|
||||||
const __m128 p = _mm_cvtepi32_ps(i32);
|
__m256 p = _mm256_cvtepi32_ps(i32);
|
||||||
|
|
||||||
// Apply the scale, and accumulate
|
// Apply the scale, and accumulate
|
||||||
acc = _mm_fmadd_ps(scale, p, acc);
|
acc = _mm256_fmadd_ps(scale, p, acc);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return horizontal sum of the acc vector
|
// Return horizontal sum of the acc vector
|
||||||
__m128 res = _mm_add_ps(acc, _mm_movehl_ps(acc, acc));
|
__m128 res = _mm256_extractf128_ps(acc, 1);
|
||||||
|
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
|
||||||
|
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
|
||||||
res = _mm_add_ss(res, _mm_movehdup_ps(res));
|
res = _mm_add_ss(res, _mm_movehdup_ps(res));
|
||||||
sumf = _mm_cvtss_f32(res);
|
sumf = _mm_cvtss_f32(res);
|
||||||
#else
|
#else
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue