This commit is contained in:
jon-chuang 2023-04-15 21:57:39 +08:00
parent 00e86b97cc
commit 6bf6543a6a

11
ggml.c
View file

@ -433,12 +433,8 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
#define QK 32 #define QK 32
// AVX routine provided by GH user jon-chuang // AVX routine provided by GH user jon-chuang
// ref: https://github.com/ggerganov/llama.cpp/issues/956#issuecomment-1508090551 #if __AVX2__ || __AVX512F__
#if false && __AVX2__ || __AVX512F__
// Given A = K X M, B = K X N, compute one row of C = A^TB // Given A = K X M, B = K X N, compute one row of C = A^TB
void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, int M, int N, int K) { void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, int M, int N, int K) {
alignas(32) float res_vec[8]; alignas(32) float res_vec[8];
@ -476,9 +472,7 @@ void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, i
_mm256_maskstore_ps(&C[j], mask_vec, c_vec); _mm256_maskstore_ps(&C[j], mask_vec, c_vec);
} }
} }
#elif __AVX__ #elif __AVX__
// Given A = K X M, B = K X N, compute one row of C = A^TB // Given A = K X M, B = K X N, compute one row of C = A^TB
void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, int M, int N, int K) { void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, int M, int N, int K) {
for (int j = 0; j < N; j += 4) { // Process 4 elements of C's row at a time - 128 / size_of(float) for (int j = 0; j < N; j += 4) { // Process 4 elements of C's row at a time - 128 / size_of(float)
@ -515,12 +509,9 @@ void ggml_mul_row_f32_tall_skinny(const float * A, const float * B, float * C, i
#endif #endif
// AVX routines provided by GH user Const-me // AVX routines provided by GH user Const-me
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600 // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
#if __AVX2__ || __AVX512F__ #if __AVX2__ || __AVX512F__
// Unpack 32 4-bit fields into 32 bytes // Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytesFromNibbles( const uint8_t* rsi ) static inline __m256i bytesFromNibbles( const uint8_t* rsi )