q4_0c: AVX512 vec_dot and quantize impl
This commit is contained in:
parent
4bd781cd25
commit
ab543dc1a4
1 changed files with 124 additions and 17 deletions
141
ggml.c
141
ggml.c
|
@ -1725,8 +1725,8 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int
|
||||||
|
|
||||||
// reference implementation for deterministic creation of model files
|
// reference implementation for deterministic creation of model files
|
||||||
static void quantize_row_q8_0c_reference(const float * restrict x, void * restrict y, int k) {
|
static void quantize_row_q8_0c_reference(const float * restrict x, void * restrict y, int k) {
|
||||||
assert(k % QK8_0 == 0);
|
assert(k % QK8_0C == 0);
|
||||||
const int nb = k / QK8_0;
|
const int nb = k / QK8_0C;
|
||||||
|
|
||||||
uint8_t * restrict qs = y;
|
uint8_t * restrict qs = y;
|
||||||
float * restrict ds = (float *) ((uint8_t *) y + QK8_0C * nb);
|
float * restrict ds = (float *) ((uint8_t *) y + QK8_0C * nb);
|
||||||
|
@ -1734,8 +1734,8 @@ static void quantize_row_q8_0c_reference(const float * restrict x, void * restri
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
float amax = 0.0f; // absolute max
|
float amax = 0.0f; // absolute max
|
||||||
|
|
||||||
for (int l = 0; l < QK8_0; l++) {
|
for (int l = 0; l < QK8_0C; l++) {
|
||||||
const float v = x[i*QK8_0 + l];
|
const float v = x[i*QK8_0C + l];
|
||||||
amax = MAX(amax, fabsf(v));
|
amax = MAX(amax, fabsf(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1744,17 +1744,46 @@ static void quantize_row_q8_0c_reference(const float * restrict x, void * restri
|
||||||
|
|
||||||
ds[i] = d;
|
ds[i] = d;
|
||||||
|
|
||||||
for (int l = 0; l < QK8_0; ++l) {
|
for (int l = 0; l < QK8_0C; ++l) {
|
||||||
const float v = x[i*QK8_0 + l]*id;
|
const float v = x[i*QK8_0C + l]*id;
|
||||||
qs[i*QK8_0 + l] = roundf(v);
|
qs[i*QK8_0C + l] = roundf(v);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void quantize_row_q8_0c(const float * restrict x, void * restrict vy, int k) {
|
static void quantize_row_q8_0c(const float * restrict x, void * restrict vy, int k) {
|
||||||
assert(k % QK8_0 == 0);
|
assert(k % QK8_0C == 0);
|
||||||
|
const int nb = k / QK8_0C;
|
||||||
|
|
||||||
|
int8_t * restrict qs = vy;
|
||||||
|
float * restrict ds = (float *) ((uint8_t *) vy + nb*QK8_0C);
|
||||||
|
|
||||||
|
#if __AVX512F__
|
||||||
|
for (int i = 0; i < nb; i++) {
|
||||||
|
const __m512 x0 = _mm512_loadu_ps( x + i*QK8_0C );
|
||||||
|
const __m512 x1 = _mm512_loadu_ps( x + i*QK8_0C + QK8_0C/2);
|
||||||
|
|
||||||
|
// Find absolute max
|
||||||
|
const __m512 x0abs = _mm512_abs_ps(x0);
|
||||||
|
const __m512 x1abs = _mm512_abs_ps(x1);
|
||||||
|
const float amax = _mm512_reduce_max_ps(_mm512_max_ps(x0abs, x1abs));
|
||||||
|
|
||||||
|
const float d = amax / ((1 << 7) - 1);
|
||||||
|
const float id = d ? 1.0f/d : 0.0f;
|
||||||
|
|
||||||
|
ds[i] = d;
|
||||||
|
|
||||||
|
const __m512 mul = _mm512_set1_ps( id );
|
||||||
|
const __m512i x0q = _mm512_cvt_roundps_epi32(_mm512_mul_ps(x0, mul), (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||||
|
const __m512i x1q = _mm512_cvt_roundps_epi32(_mm512_mul_ps(x1, mul), (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||||
|
|
||||||
|
_mm512_mask_cvtepi32_storeu_epi8(qs + i*QK8_0C, 0xffff, x0q);
|
||||||
|
_mm512_mask_cvtepi32_storeu_epi8(qs + i*QK8_0C + QK8_0C/2, 0xffff, x1q);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// scalar
|
||||||
quantize_row_q8_0c_reference(x, vy, k);
|
quantize_row_q8_0c_reference(x, vy, k);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
||||||
|
@ -2780,6 +2809,73 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
|
||||||
*s = sumf;
|
*s = sumf;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if __AVX512F__ && QK4_0 == 32
|
||||||
|
|
||||||
|
// Dot product of four blocks of q4_0c with four blocks of q8_0c
|
||||||
|
static inline __m512 dot_q4_0c_fourblocks_avx512(
|
||||||
|
__m512 acc,
|
||||||
|
const uint8_t * restrict xqs,
|
||||||
|
const float * restrict xds,
|
||||||
|
const int8_t * restrict yqs,
|
||||||
|
const float * restrict yds
|
||||||
|
) {
|
||||||
|
// load quantized bytes
|
||||||
|
// TODO: change back to aligned loads
|
||||||
|
const __m512i xqs0123 = _mm512_loadu_epi64( xqs );
|
||||||
|
const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
|
||||||
|
const __m512i xqs01 = _mm512_and_si512( low_nibble_mask, xqs0123 );
|
||||||
|
// TODO: try srlv/i?
|
||||||
|
const __m512i xqs23 = _mm512_and_si512( low_nibble_mask, _mm512_srli_epi32( xqs0123, 4 ) );
|
||||||
|
const __m512i yqs01 = _mm512_loadu_epi64( yqs );
|
||||||
|
const __m512i yqs23 = _mm512_loadu_epi64( yqs + 2*QK8_0C );
|
||||||
|
|
||||||
|
// load scales
|
||||||
|
const __m512i scale_mask0 = _mm512_set_epi32(1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0);
|
||||||
|
const __m512i scale_mask1 = _mm512_set_epi32(3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2);
|
||||||
|
const __m128 xyds = _mm_mul_ps(_mm_load_ps(xds), _mm_load_ps(yds));
|
||||||
|
const __m512 xyds0123 = _mm512_broadcast_f32x4(xyds);
|
||||||
|
const __m512 xyds01 = _mm512_permutevar_ps(xyds0123, scale_mask0);
|
||||||
|
const __m512 xyds23 = _mm512_permutevar_ps(xyds0123, scale_mask1);
|
||||||
|
|
||||||
|
// take dot product of x and y bytes
|
||||||
|
const __m512i plus_8 = _mm512_set1_epi8( 8 );
|
||||||
|
#ifdef __AVX512VNNI__
|
||||||
|
// We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
|
||||||
|
// the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
|
||||||
|
// from each nibble, so they can be negative. So, instead of `(xqs01 - 8) * yqs01`,
|
||||||
|
// we compute `xqs01 * yqs01 - 8 * yqks`.
|
||||||
|
const __m512i zero = _mm512_setzero_epi32();
|
||||||
|
const __m512i yqs01_mul8 = _mm512_dpbusds_epi32( zero, plus_8, yqs01 );
|
||||||
|
const __m512i yqs23_mul8 = _mm512_dpbusds_epi32( zero, plus_8, yqs23 );
|
||||||
|
const __m512i xy01 = _mm512_dpbusds_epi32( zero, xqs01, yqs01 );
|
||||||
|
const __m512i xy23 = _mm512_dpbusds_epi32( zero, xqs23, yqs23 );
|
||||||
|
const __m512i res0_int = _mm512_sub_epi32( xy01, yqs01_mul8 );
|
||||||
|
const __m512i res1_int = _mm512_sub_epi32( xy23, yqs23_mul8 );
|
||||||
|
#else
|
||||||
|
// As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
|
||||||
|
// It has the same catch as VPDPBUSDS: the left operand should be unsigned.
|
||||||
|
// This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
|
||||||
|
// ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
|
||||||
|
const __m512i one = _mm512_set1_epi16( 1 );
|
||||||
|
const __m512i prod_0 = _mm512_maddubs_epi16( xqs01, yqs01 );
|
||||||
|
const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, yqs01 );
|
||||||
|
const __m512i prod_2 = _mm512_maddubs_epi16( xqs23, yqs23 );
|
||||||
|
const __m512i prod_3 = _mm512_maddubs_epi16( plus_8, yqs23 );
|
||||||
|
const __m512i diff0 = _mm512_sub_epi16( prod_0, prod_1 );
|
||||||
|
const __m512i diff1 = _mm512_sub_epi16( prod_2, prod_3 );
|
||||||
|
const __m512i res0_int = _mm512_madd_epi16( diff0, one );
|
||||||
|
const __m512i res1_int = _mm512_madd_epi16( diff1, one );
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
|
||||||
|
const __m512 res0_float = _mm512_cvtepi32_ps( res0_int );
|
||||||
|
const __m512 res1_float = _mm512_cvtepi32_ps( res1_int );
|
||||||
|
|
||||||
|
return _mm512_fmadd_ps( xyds23, res1_float,
|
||||||
|
_mm512_fmadd_ps( xyds01, res0_float, acc ));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
|
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
|
||||||
ggml_float sumf = 0.0;
|
ggml_float sumf = 0.0;
|
||||||
|
|
||||||
|
@ -2999,6 +3095,15 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void
|
||||||
|
|
||||||
float sumf = 0.0;
|
float sumf = 0.0;
|
||||||
|
|
||||||
|
#if __AVX512F__
|
||||||
|
// Initialize accumulator with zeros
|
||||||
|
__m512 acc = _mm512_setzero_ps();
|
||||||
|
for (int i = 0; i < nb; i += 4) {
|
||||||
|
acc = dot_q4_0c_fourblocks_avx512(acc, xqs + i*QK4_0/2, xds + i, yqs + i*QK8_0, yds + i);
|
||||||
|
}
|
||||||
|
// Horizontal sum of all lanes of the accumulator
|
||||||
|
sumf = _mm512_reduce_add_ps( acc );
|
||||||
|
#else
|
||||||
// scalar
|
// scalar
|
||||||
for (int i = 0; i < nb/2; i++) {
|
for (int i = 0; i < nb/2; i++) {
|
||||||
const int dst0 = i + i/2*2; // 0, 1, 4, 5, 8, 9, ...
|
const int dst0 = i + i/2*2; // 0, 1, 4, 5, 8, 9, ...
|
||||||
|
@ -3009,23 +3114,25 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void
|
||||||
const float dy0 = yds[dst0];
|
const float dy0 = yds[dst0];
|
||||||
const float dy1 = yds[dst1];
|
const float dy1 = yds[dst1];
|
||||||
|
|
||||||
int sumi0 = 0;
|
// NOTE: having these as plain int triggers a bug with AVX512 on GCC 12.2
|
||||||
int sumi1 = 0;
|
int64_t sumi0 = 0;
|
||||||
|
int64_t sumi1 = 0;
|
||||||
|
|
||||||
for (int l = 0; l < QK4_0; l++) {
|
for (int l = 0; l < QK4_0; l++) {
|
||||||
const uint8_t v0 = xqs[i*QK4_0 + l];
|
const uint8_t v0 = xqs[i*QK4_0 + l];
|
||||||
|
|
||||||
const int i0 = (int8_t) (v0 & 0xf) - 8;
|
const int i0 = (int) (v0 & 0xf) - 8;
|
||||||
const int i1 = (int8_t) (v0 >> 4) - 8;
|
const int i1 = (int) (v0 >> 4) - 8;
|
||||||
|
|
||||||
const int i2 = yqs[dst0*QK4_0 + l];
|
const int i2 = yqs[dst0*QK4_0 + l];
|
||||||
const int i3 = yqs[dst1*QK4_0 + l];
|
const int i3 = yqs[dst1*QK4_0 + l];
|
||||||
|
|
||||||
sumi0 += i0*i2;
|
sumi0 += i0*i2;
|
||||||
sumi1 += i1*i3;
|
sumi1 += i1*i3;
|
||||||
}
|
}
|
||||||
sumf += dx0*dy0*sumi0 + dx1*dy1*sumi1;
|
sumf += dx0*dy0*sumi0 + dx1*dy1*sumi1;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
*s = sumf;
|
*s = sumf;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue