Add initial AVX512 support for dot product on Linux

* Update Makefile to detect AVX512 support and add compiler flags if it's available
 * Based on existing AVX2 implementation, dot product on one 32-value block of 4-bit quantized ints at a time
 * Perform 8 bit -> 16 bit sign extension and multiply+add on 32 values at time instead of 16
 * Use built-in AVX512 horizontal reduce add to get sum at the end
 * Manual unrolling on inner dot product loop to reduce loop counter overhead
This commit is contained in:
Casey Primozic 2023-03-20 04:15:40 -07:00
parent edeba28366
commit a6598801ad
No known key found for this signature in database
GPG key ID: 2A02222DA3425B99
2 changed files with 119 additions and 0 deletions

View file

@ -95,6 +95,38 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
ifneq (,$(findstring sse3,$(SSE3_M)))
CFLAGS += -msse3
endif
AVX512F_M := $(shell grep "avx512f " /proc/cpuinfo)
ifneq (,$(findstring avx512f,$(AVX512F_M)))
CFLAGS += -mavx512f
endif
AVX512BW_M := $(shell grep "avx512bw " /proc/cpuinfo)
ifneq (,$(findstring avx512bw,$(AVX512BW_M)))
CFLAGS += -mavx512bw
endif
AVX512DQ_M := $(shell grep "avx512dq " /proc/cpuinfo)
ifneq (,$(findstring avx512dq,$(AVX512DQ_M)))
CFLAGS += -mavx512dq
endif
AVX512VL_M := $(shell grep "avx512vl " /proc/cpuinfo)
ifneq (,$(findstring avx512vl,$(AVX512VL_M)))
CFLAGS += -mavx512vl
endif
AVX512CD_M := $(shell grep "avx512cd " /proc/cpuinfo)
ifneq (,$(findstring avx512cd,$(AVX512CD_M)))
CFLAGS += -mavx512cd
endif
AVX512ER_M := $(shell grep "avx512er " /proc/cpuinfo)
ifneq (,$(findstring avx512er,$(AVX512ER_M)))
CFLAGS += -mavx512er
endif
AVX512IFMA_M := $(shell grep "avx512ifma " /proc/cpuinfo)
ifneq (,$(findstring avx512ifma,$(AVX512IFMA_M)))
CFLAGS += -mavx512ifma
endif
AVX512PF_M := $(shell grep "avx512pf " /proc/cpuinfo)
ifneq (,$(findstring avx512pf,$(AVX512PF_M)))
CFLAGS += -mavx512pf
endif
else ifeq ($(UNAME_S),Haiku)
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
ifneq (,$(findstring avx,$(AVX1_M)))

87
ggml.c
View file

@ -1417,6 +1417,93 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
#else
#error "not implemented for QK"
#endif
#elif defined(__AVX512F__)
inline __m256i bytesFromNibbles( const uint8_t* rsi ){
// Load 16 bytes from memory
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
// Expand bytes into uint16_t values
__m256i bytes = _mm256_cvtepu8_epi16( tmp );
// Unpack values into individual bytes
const __m256i lowMask = _mm256_set1_epi8( 0xF );
__m256i high = _mm256_andnot_si256( lowMask, bytes );
__m256i low = _mm256_and_si256( lowMask, bytes );
high = _mm256_slli_epi16( high, 4 );
bytes = _mm256_or_si256( low, high );
return bytes;
}
inline __m512 process_one_block(
__m512 acc,
const uint8_t * pd0,
const uint8_t * pd1,
const uint8_t * pb0,
const uint8_t * pb1,
int i
) {
const float * d0_0 = (const float *) (pd0 + i*bs);
const float * d1_0 = (const float *) (pd1 + i*bs);
const uint8_t * restrict p0 = pb0 + (i+0)*bs;
const uint8_t * restrict p1 = pb1 + (i+0)*bs;
// Compute combined scale for the block
float scaleScalar = d0_0[0] * d1_0[0];
__m512 scale = _mm512_set1_ps( scaleScalar );
__m256i bx = bytesFromNibbles( p0 );
__m256i by = bytesFromNibbles( p1 );
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 );
bx = _mm256_sub_epi8( bx, off );
by = _mm256_sub_epi8( by, off );
// Sign-extend 16 signed bytes into int16_t
__m512i x32 = _mm512_cvtepi8_epi16( bx );
__m512i y32 = _mm512_cvtepi8_epi16( by );
// Compute products of int16_t integers, add pairwise
__m512i i64 = _mm512_madd_epi16( x32, y32 );
// Convert int32_t to float
__m512 p = _mm512_cvtepi32_ps( i64 );
// Apply the scale, and accumulate
return _mm512_fmadd_ps( scale, p, acc );
}
#if QK == 32
// Initialize accumulator with zeros
__m512 acc0 = _mm512_setzero_ps();
__m512 acc1 = _mm512_setzero_ps();
const int superblock_size = 8;
const int superblock_count = nb / superblock_size;
const int remainder = nb % superblock_size;
for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
int i = superblock_ix * superblock_size;
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+0 );
acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+1 );
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+2 );
acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+3 );
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+4 );
acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+5 );
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i+6 );
acc1 = process_one_block( acc1, pd0, pd1, pb0, pb1, i+7 );
}
// Remainders
for (int i = superblock_count * superblock_size; i < nb; ++i) {
acc0 = process_one_block( acc0, pd0, pd1, pb0, pb1, i );
}
// Horizontal sum of all lanes of the accumulator
sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
#else
#error "not implemented for QK"
#endif
#elif defined(__AVX2__)
#if QK == 32
const size_t countBlocks = nb;