From c806db318d202e7a5fb047e325946a52a88246eb Mon Sep 17 00:00:00 2001
From: slaren
Date: Thu, 25 Apr 2024 02:10:11 +0200
Subject: [PATCH] improve fp16 validation performance

---
 ggml-quants.c | 70 ++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 61 insertions(+), 9 deletions(-)

diff --git a/ggml-quants.c b/ggml-quants.c
index ae019c419..7b40baa03 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -12404,14 +12404,32 @@ static bool validate_float(float f, size_t i) {
     return true;
 }
 
-static bool validate_f16(ggml_fp16_t f, size_t i) {
-    return validate_float(GGML_FP16_TO_FP32(f), i);
+static bool isinf_fp16(ggml_fp16_t f) {
+    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0;
+}
+
+static bool isnan_fp16(ggml_fp16_t f) {
+    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;
+}
+
+static inline bool validate_fp16(ggml_fp16_t f, size_t i) {
+    if (isinf_fp16(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
+        return false;
+    }
+
+    if (isnan_fp16(f)) {
+        fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i);
+        return false;
+    }
+
+    return true;
 }
 
 #define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
-        if (!validate_f16(q[i].d, i)) { \
+        if (!validate_fp16(q[i].d, i)) { \
             return false; \
         } \
     }
@@ -12419,7 +12437,7 @@ static bool validate_f16(ggml_fp16_t f, size_t i) {
 #define VALIDATE_ROW_DATA_DM_F16_IMPL(type, data, nb, d, m) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
-        if (!validate_f16(q[i].d, i) || !validate_f16(q[i].m, i)) { \
+        if (!validate_fp16(q[i].d, i) || !validate_fp16(q[i].m, i)) { \
             return false; \
         } \
     }
@@ -12436,14 +12454,31 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
         return false;
     }
 
-    size_t nb = nbytes/ggml_type_size(type);
+    const size_t nb = nbytes/ggml_type_size(type);
 
     switch (type) {
         case GGML_TYPE_F16:
             {
                 const ggml_fp16_t * f = (const ggml_fp16_t *) data;
-                for (size_t i = 0; i < nb; ++i) {
-                    if (!validate_f16(f[i], i)) {
+                size_t i = 0;
+#ifdef __AVX2__
+                for (; i + 15 < nb; i += 16) {
+                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
+                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));
+                    __m256i cmp = _mm256_cmpeq_epi16(vexp, _mm256_set1_epi16(0x7c00));
+                    int mask = _mm256_movemask_epi8(cmp);
+                    if (mask) {
+                        for (size_t j = 0; j < 16; ++j) {
+                            if (!validate_fp16(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#endif
+                for (; i < nb; ++i) {
+                    if (!validate_fp16(f[i], i)) {
                         return false;
                     }
                 }
@@ -12451,7 +12486,24 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
         case GGML_TYPE_F32:
             {
                 const float * f = (const float *) data;
-                for (size_t i = 0; i < nb; ++i) {
+                size_t i = 0;
+#ifdef __AVX2__
+                for (; i + 7 < nb; i += 8) {
+                    __m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
+                    __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));
+                    __m256i cmp = _mm256_cmpeq_epi32(vexp, _mm256_set1_epi32(0x7f800000));
+                    int mask = _mm256_movemask_epi8(cmp);
+                    if (mask) {
+                        for (size_t j = 0; j < 8; ++j) {
+                            if (!validate_float(f[i + j], i + j)) {
+                                return false;
+                            }
+                        }
+                        GGML_UNREACHABLE();
+                    }
+                }
+#endif
+                for (; i < nb; ++i) {
                     if (!validate_float(f[i], i)) {
                         return false;
                     }
@@ -12539,7 +12591,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
                     iq1m_scale_t scale;
                     const uint16_t * sc = (const uint16_t *)q[i].scales;
                     scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
-                    if (!validate_f16(scale.f16, i)) {
+                    if (!validate_fp16(scale.f16, i)) {
                         return false;
                     }
 #endif
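
Side note (not part of the patch): the bit tests above rely on the IEEE 754 binary16 layout, where a value is inf or nan exactly when all exponent bits (0x7c00) are set, and the mantissa bits (0x03ff) then distinguish inf (zero mantissa) from nan (non-zero mantissa). A minimal standalone sketch, using plain uint16_t in place of ggml_fp16_t and a few hypothetical sample bit patterns, illustrates the same classification:

/* Standalone sketch, not from the patch: mirrors isinf_fp16/isnan_fp16
 * on raw uint16_t values instead of ggml_fp16_t. */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

static bool sketch_isinf_fp16(uint16_t f) {
    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) == 0;  /* exponent all ones, mantissa zero */
}

static bool sketch_isnan_fp16(uint16_t f) {
    return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;  /* exponent all ones, mantissa non-zero */
}

int main(void) {
    /* sample patterns: 1.0, +inf, -inf, a quiet nan, 65504.0 (largest normal half) */
    const uint16_t samples[] = { 0x3c00, 0x7c00, 0xfc00, 0x7e00, 0x7bff };
    for (size_t i = 0; i < sizeof(samples)/sizeof(samples[0]); ++i) {
        printf("0x%04x: inf=%d nan=%d\n", samples[i],
               sketch_isinf_fp16(samples[i]), sketch_isnan_fp16(samples[i]));
    }
    return 0;
}

The AVX2 loops in the patch apply the same exponent mask to 16 fp16 (or 8 fp32) lanes at a time and only drop into the per-element validate_fp16/validate_float path when _mm256_movemask_epi8 reports at least one suspicious lane, so the common all-finite case avoids the per-element conversion and branching entirely.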