Edit comments

This commit is contained in:
Srihari-mcw 2024-09-23 01:33:48 -07:00
parent 7436d52922
commit 14d2abb8eb

View file

@@ -121,7 +121,7 @@ static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrang
#if defined(__AVX2__) || defined(__AVX512F__) #if defined(__AVX2__) || defined(__AVX512F__)
#if defined(__AVX512F__) #if defined(__AVX512F__)
// add int16_t pairwise and return as int vector // add int16_t pairwise and return as 512 bit int vector
static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) { static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
const __m512i ones = _mm512_set1_epi16(1); const __m512i ones = _mm512_set1_epi16(1);
return _mm512_madd_epi16(ones, x); return _mm512_madd_epi16(ones, x);
@@ -138,7 +138,7 @@ static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i
#endif #endif
} }
// multiply int8_t, add results pairwise twice and return as int vector // multiply int8_t, add results pairwise twice and return as 512 bit int vector
static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) { static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) {
const __m512i zero = _mm512_setzero_si512(); const __m512i zero = _mm512_setzero_si512();
// Get absolute values of x vectors // Get absolute values of x vectors
@@ -150,6 +150,7 @@ static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y
} }
#endif #endif
// add int16_t pairwise and return as 256 bit int vector
static inline __m256i sum_i16_pairs_int32x8(const __m256i x) { static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
const __m256i ones = _mm256_set1_epi16(1); const __m256i ones = _mm256_set1_epi16(1);
return _mm256_madd_epi16(ones, x); return _mm256_madd_epi16(ones, x);
@@ -167,7 +168,7 @@ static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i
} }
// Integer variant of the function defined in ggml-quants.c // Integer variant of the function defined in ggml-quants.c
// multiply int8_t, add results pairwise twice and return as int vector // multiply int8_t, add results pairwise twice and return as 256 bit int vector
static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) { static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) {
#if __AVXVNNIINT8__ #if __AVXVNNIINT8__
const __m256i zero = _mm256_setzero_si256(); const __m256i zero = _mm256_setzero_si256();