add neon impl

This commit is contained in:
slaren 2024-04-26 03:26:39 +02:00
parent cf4fa0c193
commit 55dec7c4a8

View file

@ -12412,7 +12412,7 @@ static bool isnan_fp16(ggml_fp16_t f) {
return (f & 0x7c00) == 0x7c00 && (f & 0x03ff) != 0;
}
static inline bool validate_fp16(ggml_fp16_t f, size_t i) {
static bool validate_fp16(ggml_fp16_t f, size_t i) {
if (isinf_fp16(f)) {
fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i);
return false;
@ -12448,7 +12448,6 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
return false;
}
// size check
if (nbytes % ggml_type_size(type) != 0) {
fprintf(stderr, "%s: invalid size %zu for type %d\n", __func__, nbytes, type);
return false;
@ -12461,7 +12460,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
const ggml_fp16_t * f = (const ggml_fp16_t *) data;
size_t i = 0;
#ifdef __AVX2__
#if defined(__AVX2__)
for (; i + 15 < nb; i += 16) {
__m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
__m256i vexp = _mm256_and_si256(v, _mm256_set1_epi16(0x7c00));
@ -12476,6 +12475,21 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
GGML_UNREACHABLE();
}
}
#elif defined(__ARM_NEON)
for (; i + 7 < nb; i += 8) {
uint16x8_t v = vld1q_u16(f + i);
uint16x8_t vexp = vandq_u16(v, vdupq_n_u16(0x7c00));
uint16x8_t cmp = vceqq_u16(vexp, vdupq_n_u16(0x7c00));
uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(cmp, 4)), 0);
if (mask) {
for (size_t j = 0; j < 8; ++j) {
if (!validate_fp16(f[i + j], i + j)) {
return false;
}
}
GGML_UNREACHABLE();
}
}
#endif
for (; i < nb; ++i) {
if (!validate_fp16(f[i], i)) {
@ -12487,7 +12501,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
const float * f = (const float *) data;
size_t i = 0;
#ifdef __AVX2__
#if defined(__AVX2__)
for (; i + 7 < nb; i += 8) {
__m256i v = _mm256_loadu_si256((const __m256i *)(f + i));
__m256i vexp = _mm256_and_si256(v, _mm256_set1_epi32(0x7f800000));
@ -12502,6 +12516,21 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
GGML_UNREACHABLE();
}
}
#elif defined(__ARM_NEON)
for (; i + 3 < nb; i += 4) {
uint32x4_t v = vld1q_u32((const uint32_t *)f + i);
uint32x4_t vexp = vandq_u32(v, vdupq_n_u32(0x7f800000));
uint32x4_t cmp = vceqq_u32(vexp, vdupq_n_u32(0x7f800000));
uint64_t mask = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u32(cmp, 8)), 0);
if (mask) {
for (size_t j = 0; j < 4; ++j) {
if (!validate_float(f[i + j], i + j)) {
return false;
}
}
GGML_UNREACHABLE();
}
}
#endif
for (; i < nb; ++i) {
if (!validate_float(f[i], i)) {