diff --git a/ggml-quants.c b/ggml-quants.c index 7b40baa03..a000b352c 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -12412,7 +12412,7 @@ static bool isnan_fp16(ggml_fp16_t f) { return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0; } -static inline bool validate_fp16(ggml_fp16_t f, size_t i) { +static bool validate_fp16(ggml_fp16_t f, size_t i) { if (isinf_fp16(f)) { fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i); return false; @@ -12448,7 +12448,6 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte return false; } - // size check if (nbytes % ggml_type_size(type) != 0) { fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type); return false; @@ -12461,7 +12460,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { const ggml_fp16_t * f = (const ggml_fp16_t *) data; size_t i = 0; -#ifdef __AVX2__ +#if defined(__AVX2__) for (; i + 15 < nb; i += 16) { __m256i v = _mm256_loadu_si256((const __m256i *)(f + i)); __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00)); @@ -12476,6 +12475,21 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte GGML_UNREACHABLE(); } } +#elif defined(__ARM_NEON) + for (; i + 7 < nb; i += 8) { + uint16x8_t v = vld1q_u16(f + i); + uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00)); + uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00)); + uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0); + if (mask) { + for (size_t j = 0; j < 8; ++j) { + if (!validate_fp16(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } #endif for (; i < nb; ++i) { if (!validate_fp16(f[i], i)) { @@ -12487,7 +12501,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { const float * f = (const float *) data; size_t i = 0; -#ifdef __AVX2__ +#if defined(__AVX2__) for (; i + 7 < nb; i += 8) { __m256i v = _mm256_loadu_si256((const __m256i *)(f + i)); __m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000)); @@ -12502,6 +12516,21 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte GGML_UNREACHABLE(); } } +#elif defined(__ARM_NEON) + for (; i + 3 < nb; i += 4) { + uint32x4_t v = vld1q_u32((const uint32_t *)f + i); + uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000)); + uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000)); + uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u32(cmp, 8)), 0); + if (mask) { + for (size_t j = 0; j < 4; ++j) { + if (!validate_float(f[i + j], i + j)) { + return false; + } + } + GGML_UNREACHABLE(); + } + } #endif for (; i < nb; ++i) { if (!validate_float(f[i], i)) {