support axpy q4_0 for loop

This commit is contained in:
syx 2023-12-12 15:03:10 +08:00
parent 9975f4aaa7
commit c796dd4c90

View file

@ -2430,7 +2430,7 @@ void ggml_axpy_q4_0_q8_0(const int n, const void * restrict vx, const void * res
assert(nb % 2 == 0);
const block_q4_0 * restrict x = vx;
#if defined(__AVX2__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
@ -2491,7 +2491,21 @@ void ggml_axpy_q4_0_q8_0(const int n, const void * restrict vx, const void * res
_mm256_storeu_ps((__m256*)(vz + i*128+96), by);
}
#else
float *res = (float *)vz;
float scale_fp32 = GGML_FP16_TO_FP32(scale);
for (int i = 0; i < nb; i++) {
float result_scale = GGML_FP16_TO_FP32(x[i].d) * scale_fp32;
int offset = i * QK4_0;
for (int j = 0; j < qk/2; ++j) {
const int v0 = (x[i].qs[j] & 0x0F) - 8;
const int v1 = (x[i].qs[j] >> 4) - 8;
res[offset + j] = res[offset + j] + ((float)(v0 * (int)alpha) * result_scale);
res[offset + j + qk/2] = res[offset + j + qk/2] + ((float)(v1 * (int)alpha) * result_scale);
}
}
#endif
}