iq3_s_mult_shuffle: mult + shuffle based codebook
This commit is contained in:
parent
b48bf8b411
commit
b587482287
2 changed files with 81 additions and 46 deletions
|
@ -4058,12 +4058,18 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
|
|||
// Best PPL
|
||||
#define IQ3S_MULTIPLIER 190842953
|
||||
#else
|
||||
#define IQ3S_MULTIPLIER 898886
|
||||
#define IQ3S_MULTIPLIER 72968561ULL
|
||||
//#define IQ3S_MULTIPLIER 540201
|
||||
//#define IQ3S_MULTIPLIER 1378231
|
||||
//#define IQ3S_MULTIPLIER 898886
|
||||
//#define IQ3S_MULTIPLIER 842866
|
||||
#endif
|
||||
|
||||
#define IQ3S_BITS 3
|
||||
|
||||
// Value table for the mult + shuffle codebook: a 4-bit index selects one of
// these odd magnitudes. The indices used with it elsewhere in this file come
// from a multiplier product masked with 0x0f0f0f0f (presumably one index per
// byte -- confirm against the grid/q4 aliases). The table is exactly 16 bytes
// so it can also be loaded whole into a 128-bit register and used as the
// lookup operand of a byte shuffle.
static const uint8_t iq3s_values[16] = {
     1,  1,  1,  3,   // indices  0..3
     3,  3,  5,  5,   // indices  4..7
     5,  7,  7,  9,   // indices  8..11
     9, 11, 13, 15,   // indices 12..15
};
|
||||
//static const uint8_t iq3s_values[16] = {1, 1, 1, 3, 3, 3, 5, 5, 7, 7, 9, 9, 11, 11, 13, 15};
|
||||
|
||||
void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, int k) {
|
||||
assert(k % QK_K == 0);
|
||||
const int nb = k / QK_K;
|
||||
|
@ -4099,10 +4105,15 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
|
|||
y[j] = dl * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||
}
|
||||
#else
|
||||
aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
|
||||
aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
|
||||
//aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
|
||||
//aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
|
||||
//for (int j = 0; j < 8; ++j) {
|
||||
// y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||
//}
|
||||
aux32[0] = (((qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
|
||||
aux32[1] = (((qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||
y[j] = dl * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||
}
|
||||
#endif
|
||||
y += 8;
|
||||
|
@ -4118,12 +4129,17 @@ void dequantize_row_iq3_s(const block_iq3_s * restrict x, float * restrict y, in
|
|||
y[j] = dl * (2*((grid[j]-1)/2) + 1) * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||
}
|
||||
#else
|
||||
aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
|
||||
aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
|
||||
#endif
|
||||
//aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
|
||||
//aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f) | 0x01010101;
|
||||
//for (int j = 0; j < 8; ++j) {
|
||||
// y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||
//}
|
||||
aux32[0] = (((qs[2*l+0] | ((qh[1] << (8-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
|
||||
aux32[1] = (((qs[2*l+1] | ((qh[1] << (7-2*l)) & 256)) * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
y[j] = dl * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||
y[j] = dl * iq3s_values[grid[j]] * (signs[l] & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||
}
|
||||
#endif
|
||||
y += 8;
|
||||
}
|
||||
qh += 2;
|
||||
|
@ -10073,12 +10089,13 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
|
||||
const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1);
|
||||
const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2);
|
||||
const __m128i shuffle128 = _mm_loadu_si128((const __m128i *)iq3s_values);
|
||||
const __m256i shuffle = _mm256_set_m128i(shuffle128, shuffle128);
|
||||
|
||||
const __m256i idx_mask = _mm256_set1_epi32(256);
|
||||
const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8);
|
||||
const __m256i idx_mult = _mm256_set1_epi32(IQ3S_MULTIPLIER);
|
||||
const __m256i m1 = _mm256_set1_epi8(1);
|
||||
//const __m256i m1 = _mm256_set1_epi8(1);
|
||||
const __m256i m15 = _mm256_set1_epi32(0x0f0f0f0f);
|
||||
const __m256i m100 = _mm256_set1_epi32(0x0100);
|
||||
#ifdef IQ3S_SLOW_MULT
|
||||
const __m256i m7 = _mm256_set1_epi32(0x07070707);
|
||||
const __m256i m0 = _mm256_setzero_si256();
|
||||
|
@ -10096,12 +10113,19 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
||||
const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
||||
const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
|
||||
const __m128i idx_l_8 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
|
||||
const __m256i idx_l_16 = _mm256_cvtepu8_epi16(idx_l_8);
|
||||
const __m256i idx_h_l = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+0]), idx_shift), idx_mask);
|
||||
const __m256i idx_h_h = _mm256_and_si256(_mm256_sllv_epi32(_mm256_set1_epi32(qh[ib32+1]), idx_shift), idx_mask);
|
||||
const __m256i idx_32_l = _mm256_or_si256(idx_h_l, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l_16)));
|
||||
const __m256i idx_32_h = _mm256_or_si256(idx_h_h, _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l_16, 1)));
|
||||
|
||||
const __m256i q3_low_bytes_1 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8;
|
||||
const __m256i q3_low_bytes_2 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)qs)); qs += 8;
|
||||
uint64_t high_bits_spread_1 = ((uint64_t)qh[ib32+0] * 0x0101010101010101ULL) & 0x8040201008040201ULL;
|
||||
uint64_t high_bits_spread_2 = ((uint64_t)qh[ib32+1] * 0x0101010101010101ULL) & 0x8040201008040201ULL;
|
||||
const __m256i high_bits_in_low_1 = _mm256_cmpgt_epi32(
|
||||
_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_1)),
|
||||
_mm256_setzero_si256());
|
||||
const __m256i high_bits_in_low_2 = _mm256_cmpgt_epi32(
|
||||
_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)&high_bits_spread_2)),
|
||||
_mm256_setzero_si256());
|
||||
const __m256i idx_32_l = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_1), q3_low_bytes_1);
|
||||
const __m256i idx_32_h = _mm256_or_si256(_mm256_and_si256(m100, high_bits_in_low_2), q3_low_bytes_2);
|
||||
|
||||
#ifdef IQ3S_SLOW_MULT
|
||||
const __m256i idx_l = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1), m0);
|
||||
|
@ -10109,12 +10133,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
const __m256i idx_h = _mm256_max_epi8(_mm256_sub_epi8(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1), m0);
|
||||
const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1);
|
||||
#else
|
||||
//const __m256i idx_l = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1);
|
||||
//const __m256i q2_1 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_l, 1), m7), 1), m1);
|
||||
//const __m256i idx_h = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1);
|
||||
//const __m256i q2_2 = _mm256_or_si256(_mm256_slli_epi32(_mm256_and_si256(_mm256_srli_epi32(idx_h, 1), m7), 1), m1);
|
||||
const __m256i q2_1 = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15), m1);
|
||||
const __m256i q2_2 = _mm256_or_si256(_mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15), m1);
|
||||
const __m256i q2_1 = _mm256_shuffle_epi8(shuffle, _mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_l), m15));
|
||||
const __m256i q2_2 = _mm256_shuffle_epi8(shuffle, _mm256_and_si256(_mm256_mullo_epi32(idx_mult, idx_32_h), m15));
|
||||
#endif
|
||||
|
||||
__m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16));
|
||||
|
@ -11364,10 +11384,14 @@ static void iq3xs_init_grid512(void) {
|
|||
#ifdef IQ3S_SLOW_MULT
|
||||
aux32 = ((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f;
|
||||
#else
|
||||
aux32 = (((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f) | 0x01010101;
|
||||
//aux32 = (((uint64_t)IQ3S_MULTIPLIER * k) & 0x0f0f0f0f) | 0x01010101;
|
||||
aux32 = ((k * IQ3S_MULTIPLIER) & 0x0f0f0f0f);
|
||||
#endif
|
||||
//for (int i = 0; i < 4; ++i) {
|
||||
// pos[i] = 2*((q4[i]-1)/2) + 1;
|
||||
//}
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
pos[i] = 2*((q4[i]-1)/2) + 1;
|
||||
pos[i] = iq3s_values[q4[i]];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue