Fix dequantization - forgot to interleave the quants

This commit is contained in:
Georgi Gerganov 2023-03-25 19:31:23 +02:00
parent 04be5b0ba4
commit b83ddbd768
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

29
ggml.c
View file

@ -794,6 +794,8 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
const uint8_t * restrict pp = pb + i*bs;
const float32x4_t vd = vdupq_n_f32(d);
for (int l = 0; l < QK; l += 16) {
// Load 16x4-bit integers into 8x8-bit integers
const uint8x8_t v8 = vld1_u8(pp + l/2);
@ -802,8 +804,6 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
const uint8x8_t v1 = vshr_n_u8(v8, 4);
/*printf("v0: %4d %4d %4d %4d %4d %4d %4d %4d\n", v0[0], v0[1], v0[2], v0[3], v0[4], v0[5], v0[6], v0[7]);*/
// Convert to signed 8-bit integers
const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
@ -812,28 +812,27 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
/*printf("vb_0: %4d %4d %4d %4d %4d %4d %4d %4d\n", vb_0[0], vb_0[1], vb_0[2], vb_0[3], vb_0[4], vb_0[5], vb_0[6], vb_0[7]);*/
/*printf("vb_1: %4d %4d %4d %4d %4d %4d %4d %4d\n", vb_1[0], vb_1[1], vb_1[2], vb_1[3], vb_1[4], vb_1[5], vb_1[6], vb_1[7]);*/
// Interleave and combine
const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
// Convert to 16-bit integers
const int16x8_t vi_0 = vmovl_s8(vb_0);
const int16x8_t vi_1 = vmovl_s8(vb_1);
const int8x16_t vq = vcombine_s8(vx_0, vx_1);
/*printf("vi_0: %4d %4d %4d %4d %4d %4d %4d %4d\n", vi_0[0], vi_0[1], vi_0[2], vi_0[3], vi_0[4], vi_0[5], vi_0[6], vi_0[7]);*/
// convert to 2x int16x8_t
const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
// Convert to 32-bit floats
// convert to 4x float32x4_t
const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
/*printf("vf: %4.1f %4.1f %4.1f %4.1f %4.1f %4.1f %4.1f %4.1f\n", vf_0[0], vf_0[1], vf_0[2], vf_0[3], vf_1[0], vf_1[1], vf_1[2], vf_1[3]);*/
// Multiply by d
const float32x4_t r0 = vmulq_n_f32(vf_0, d);
const float32x4_t r1 = vmulq_n_f32(vf_1, d);
const float32x4_t r2 = vmulq_n_f32(vf_2, d);
const float32x4_t r3 = vmulq_n_f32(vf_3, d);
const float32x4_t r0 = vmulq_f32(vf_0, vd);
const float32x4_t r1 = vmulq_f32(vf_1, vd);
const float32x4_t r2 = vmulq_f32(vf_2, vd);
const float32x4_t r3 = vmulq_f32(vf_3, vd);
// Store
vst1q_f32(y + i*QK + l + 0, r0);