iq4_xs: AVX2 dot product

Iwan Kawrakow 2024-02-26 15:55:08 +02:00
parent fddbfe839a
commit 061a16f5a2
5 changed files with 91 additions and 83 deletions
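
For orientation, the new kernels work on IQ4_XS super-blocks of QK_K = 256 weights and, after this change, on Q8_K activation blocks of the same size. A sketch of the two block layouts as defined elsewhere in ggml at this point (shown here for reference only, not part of the diff):

typedef struct {
    ggml_fp16_t d;                 // super-block scale
    uint16_t    scales_h;          // upper 2 bits of the eight 6-bit sub-block scales
    uint8_t     scales_l[QK_K/64]; // lower 4 bits of the sub-block scales, two per byte
    uint8_t     qs[QK_K/2];        // 256 4-bit indices into the non-linear table kvalues_iq4nl
} block_iq4_xs;

typedef struct {
    float   d;                     // block scale
    int8_t  qs[QK_K];              // 256 quantized activations
    int16_t bsums[QK_K/16];        // partial sums of qs, used by some kernels
} block_q8_K;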

ggml-cuda.cu

@@ -571,7 +571,8 @@ typedef struct {
 } block_iq4_nl;
 static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");

-#define QR4_XS 4
+// QR4_XS = 8 is very slightly faster than QR4_XS = 4
+#define QR4_XS 8
 #define QI4_XS (QK_K / (4*QR4_XS))
 typedef struct {
     half d;
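
With QK_K = 256, the new setting halves QI4_XS, so each call of the vec_dot kernel below covers a full 32-value sub-block instead of half of one:

// QR4_XS = 4  ->  QI4_XS = QK_K/(4*4) = 16   (iqs runs 0...15, 16 weights per call)
// QR4_XS = 8  ->  QI4_XS = QK_K/(4*8) =  8   (iqs runs 0...7,  32 weights per call)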
@@ -5341,21 +5342,57 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
     const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
     const uint8_t * values = (const uint8_t *)kvalues_iq4nl;

-    // iqs is 0...15
-    const int ib32 = iqs/2;
-    const int il = iqs%2;
-    const int32_t * q8 = (const int *)bq8_1[ib32].qs + 2*il;
-    const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32 + 2*il;
+    //// iqs is 0...7
+    //const int ib64 = iqs/2;
+    //const int il = iqs%2;
+    //const int32_t * q8_1 = (const int *)bq8_1[2*ib64+0].qs + 2*il;
+    //const int32_t * q8_2 = (const int *)bq8_1[2*ib64+1].qs + 2*il;
+    //const uint32_t * q4_1 = (const uint32_t *)bq4->qs + 8*ib64 + 2*il;
+    //const uint32_t * q4_2 = q4_1 + 4;
+    //const int8_t ls1 = (bq4->scales_l[ib64] & 0xf) | (((bq4->scales_h >> (4*ib64+0)) & 3) << 4);
+    //const int8_t ls2 = (bq4->scales_l[ib64] >> 4) | (((bq4->scales_h >> (4*ib64+2)) & 3) << 4);
+    //const float d1 = (float)bq4->d * (ls1 - 32) * __low2float(bq8_1[2*ib64+0].ds);
+    //const float d2 = (float)bq4->d * (ls2 - 32) * __low2float(bq8_1[2*ib64+1].ds);
+    //int v1, v2;
+    //int sumi1 = 0, sumi2 = 0;
+    //for (int j = 0; j < 2; ++j) {
+    //    get_int_from_table_16(q4_1[j], values, v1, v2);
+    //    sumi1 = __dp4a(v2, q8_1[j+4], __dp4a(v1, q8_1[j+0], sumi1));
+    //    get_int_from_table_16(q4_2[j], values, v1, v2);
+    //    sumi2 = __dp4a(v2, q8_2[j+4], __dp4a(v1, q8_2[j+0], sumi2));
+    //}
+    //return d1 * sumi1 + d2 * sumi2;
+
+    // iqs is 0...7
+    const int ib32 = iqs;
+    const int32_t * q8 = (const int *)bq8_1[ib32].qs;
+    const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
     const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
     const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds);
     int v1, v2;
     int sumi1 = 0, sumi2 = 0;
-    for (int j = 0; j < 2; ++j) {
+    for (int j = 0; j < 4; ++j) {
         get_int_from_table_16(q4[j], values, v1, v2);
         sumi1 = __dp4a(v1, q8[j+0], sumi1);
         sumi2 = __dp4a(v2, q8[j+4], sumi2);
     }
     return d * (sumi1 + sumi2);
+
+    //// iqs is 0...15
+    //const int ib32 = iqs/2;
+    //const int il = iqs%2;
+    //const int32_t * q8 = (const int *)bq8_1[ib32].qs + 2*il;
+    //const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32 + 2*il;
+    //const int8_t ls = ((bq4->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((bq4->scales_h >> 2*ib32) & 3) << 4);
+    //const float d = (float)bq4->d * (ls - 32) * __low2float(bq8_1[ib32].ds);
+    //int v1, v2;
+    //int sumi1 = 0, sumi2 = 0;
+    //for (int j = 0; j < 2; ++j) {
+    //    get_int_from_table_16(q4[j], values, v1, v2);
+    //    sumi1 = __dp4a(v1, q8[j+0], sumi1);
+    //    sumi2 = __dp4a(v2, q8[j+4], sumi2);
+    //}
+    //return d * (sumi1 + sumi2);
 #else
     assert(false);
     return 0.f;
@@ -5366,41 +5403,6 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
 #endif
 }

-//// TODO
-//static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
-//    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-//
-//    const block_iq4_xs * bq = (const block_iq4_xs *) vbq;
-//
-//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-//    const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
-//    const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs;
-//
-//    const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
-//
-//    int v1, v2;
-//    int sumi1 = 0, sumi2 = 0;
-//    for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
-//        const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
-//        get_int_from_table_16(aux, values, v1, v2);
-//        sumi1 = __dp4a(v1, q8[l+0], sumi1);
-//        sumi2 = __dp4a(v2, q8[l+4], sumi2);
-//    }
-//
-//#else
-//    const uint8_t * q4 = bq->qs + 4*iqs;
-//    const int8_t  * q8 = bq8_1->qs + 4*iqs;
-//
-//    int sumi1 = 0, sumi2 = 0;
-//    for (int l = 0; l < 4*VDR_Q4_0_Q8_1_MMVQ; ++l) {
-//        sumi1 += q8[l+ 0] * kvalues_iq4nl[q4[l] & 0xf];
-//        sumi2 += q8[l+16] * kvalues_iq4nl[q4[l] >> 4];
-//    }
-//#endif
-//    const float d = (float)bq->d * __low2float(bq8_1->ds);
-//    return d * (sumi1 + sumi2);
-//}
-
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
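
The CUDA kernel leans on get_int_from_table_16(), which is defined elsewhere in ggml-cuda.cu and is not touched by this diff. Roughly, it expands eight packed 4-bit indices into two 32-bit words of looked-up int8 values so they can be fed straight to __dp4a; a sketch of such a helper (an assumption about its shape, not a copy of the real one):

static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values,
        int & val1, int & val2) {
    uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
    aux32 = q4 & 0x0f0f0f0f;              // low nibbles -> first four table values
    uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
    uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
    val1 = v1 | (v2 << 16);
    aux32 = (q4 >> 4) & 0x0f0f0f0f;       // high nibbles -> second four table values
    v1 = values[q8[0]] | (values[q8[1]] << 8);
    v2 = values[q8[2]] | (values[q8[3]] << 8);
    val2 = v1 | (v2 << 16);
}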

ggml-quants.c

@@ -10444,21 +10444,20 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 #endif
 }

-void ggml_vec_dot_iq4_xs_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
+void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
     assert(nrc == 1);
     UNUSED(nrc);
     UNUSED(bx);
     UNUSED(by);
     UNUSED(bs);
     assert(n % QK_K == 0);
-    static_assert(QK8_0 == 32, "QK8_0 must be 32");

     const block_iq4_xs * restrict x = vx;
-    const block_q8_0 * restrict y = vy;
+    const block_q8_K * restrict y = vy;

     const int nb = n / QK_K;

-#if defined z__ARM_NEON
+#if defined __ARM_NEON
     const int8x16_t values = vld1q_s8(kvalues_iq4nl);
     const uint8x16_t m4b = vdupq_n_u8(0x0f);
     uint8x16x2_t q4bits;
@@ -10492,65 +10491,72 @@ void ggml_vec_dot_iq4_xs_q8_0(int n, float * restrict s, size_t bs, const void *
     *s = sumf;

-#elif defined z__AVX2__
+#elif defined __AVX2__
     const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl);
     const __m128i m4b = _mm_set1_epi8(0x0f);
-    const __m256i mone = _mm256_set1_epi16(1);

-    __m256 accum1 = _mm256_setzero_ps();
-    __m256 accum2 = _mm256_setzero_ps();
-    for (int ib = 0; ib < nb; ib += 2) {
-        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[0].qs);
-        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[1].qs);
-        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[0].qs);
-        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[1].qs);
-        const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
-                                               _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
-        const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
-                                               _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
-        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
-        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
-        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
-        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
-        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)*GGML_FP16_TO_FP32(x[0].d)),
-                _mm256_cvtepi32_ps(p_1), accum1);
-        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)*GGML_FP16_TO_FP32(x[1].d)),
-                _mm256_cvtepi32_ps(p_2), accum2);
-
-        y += 2;
-        x += 2;
+    __m256 accum = _mm256_setzero_ps();
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const uint8_t * qs = x[ibl].qs;
+        const int8_t * q8 = y[ibl].qs;
+        uint16_t sh = x[ibl].scales_h;
+        __m256i sumi1 = _mm256_setzero_si256();
+        __m256i sumi2 = _mm256_setzero_si256();
+        for (int ib = 0; ib < QK_K/32; ib += 2) {
+            const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
+            const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16;
+            const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
+            const __m256i q4b_1 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                                   _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+            const __m256i q4b_2 = _mm256_set_m128i(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                                   _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+            const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+            const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+            const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+            const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32;
+            sh >>= 4;
+            const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1));
+            const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2));
+            sumi1 = _mm256_add_epi32(p_1, sumi1);
+            sumi2 = _mm256_add_epi32(p_2, sumi2);
+        }
+        accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d),
+                _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum);
     }

-    *s = hsum_float_8(_mm256_add_ps(accum1, accum2));
+    *s = hsum_float_8(accum);

 #else
     float sumf = 0;
     for (int ibl = 0; ibl < nb; ++ibl) {
-        const float d4 = GGML_FP16_TO_FP32(x[ibl].d);
+        const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d;
         uint16_t h = x[ibl].scales_h;
         const uint8_t * qs = x[ibl].qs;
+        const int8_t * q8 = y[ibl].qs;
         for (int ib = 0; ib < QK_K/32; ib += 2) {
             const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30);
             const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30);
             h >>= 4;
-            const float d1 = GGML_FP16_TO_FP32(y[0].d)*d4*(ls1 - 32);
-            const float d2 = GGML_FP16_TO_FP32(y[1].d)*d4*(ls2 - 32);
+            const float d1 = d4d8*(ls1 - 32);
+            const float d2 = d4d8*(ls2 - 32);
             int sumi1 = 0, sumi2 = 0;
             for (int j = 0; j < 16; ++j) {
-                sumi1 += y[0].qs[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += y[0].qs[j+16] * kvalues_iq4nl[qs[j] >> 4];
+                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
             }
             sumf += d1 * (sumi1 + sumi2);
             qs += 16;
+            q8 += 32;
             sumi1 = sumi2 = 0;
             for (int j = 0; j < 16; ++j) {
-                sumi1 += y[1].qs[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
-                sumi2 += y[1].qs[j+16] * kvalues_iq4nl[qs[j] >> 4];
+                sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf];
+                sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4];
             }
             sumf += d2 * (sumi1 + sumi2);
             qs += 16;
-            y += 2;
+            q8 += 32;
         }
     }
     *s = sumf;
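
The AVX2 path builds on mul_add_epi8(), a small helper already used by the other AVX2 dot products in ggml-quants.c and not part of this diff. A sketch of the usual signed-multiply trick it wraps around _mm256_maddubs_epi16 (an assumption, not a copy of the real helper):

static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
    const __m256i ax = _mm256_sign_epi8(x, x);  // |x|
    const __m256i sy = _mm256_sign_epi8(y, x);  // y with the sign of x folded in
    return _mm256_maddubs_epi16(ax, sy);        // unsigned*signed multiply, pairwise add to int16
}

Multiplying these int16 partial products by _mm256_set1_epi16(ls1) / _mm256_set1_epi16(ls2) inside _mm256_madd_epi16 folds the per-sub-block scale into the integer accumulation, which is why the mone vector of the old q8_0 version is no longer needed and the float scale is applied only once per super-block.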

ggml-quants.h

@@ -322,7 +322,7 @@ void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
 void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
-void ggml_vec_dot_iq4_xs_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 //

ggml.c

@@ -734,8 +734,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .to_float             = (ggml_to_float_t) dequantize_row_iq4_xs,
         .from_float           = quantize_row_iq4_xs,
         .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
-        .vec_dot              = ggml_vec_dot_iq4_xs_q8_0,
-        .vec_dot_type         = GGML_TYPE_Q8_0,
+        .vec_dot              = ggml_vec_dot_iq4_xs_q8_K,
+        .vec_dot_type         = GGML_TYPE_Q8_K,
         .nrows                = 1,
     },
     [GGML_TYPE_Q8_K] = {
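
Switching vec_dot_type from GGML_TYPE_Q8_0 to GGML_TYPE_Q8_K tells the matrix-multiply code to quantize the activations into 256-value Q8_K blocks before calling the dot product. A condensed, hypothetical view of how the type traits are consumed (the real dispatch in ggml.c is more involved):

// hypothetical sketch of the dispatch, for illustration only
const enum ggml_type vdt = type_traits[GGML_TYPE_IQ4_XS].vec_dot_type;  // GGML_TYPE_Q8_K after this change
ggml_from_float_t to_vdt = type_traits[vdt].from_float;                 // quantize_row_q8_K
// each src1 row is converted with to_vdt, then ggml_vec_dot_iq4_xs_q8_K()
// runs on (IQ4_XS row, Q8_K row) pairs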

tests/test-backend-ops.cpp

@@ -1918,7 +1918,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         GGML_TYPE_Q6_K,
         GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
         GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S,
-        GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S,
+        GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
     };

     // unary ops