diff --git a/otherarch/ggml_rwkv.c b/otherarch/ggml_rwkv.c
index 12736cb55..7fa6a7d2f 100644
--- a/otherarch/ggml_rwkv.c
+++ b/otherarch/ggml_rwkv.c
@@ -1061,50 +1061,50 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
             }
         }
     }
-#elif defined(__ARM_NEON)
-    for (int i = 0; i < nb; i++) {
-        const float32x4_t vd = vdupq_n_f32(x[i].d);
-        const float32x4_t vm = vdupq_n_f32(x[i].m);
+// #elif defined(__ARM_NEON)
+//     for (int i = 0; i < nb; i++) {
+//         const float32x4_t vd = vdupq_n_f32(x[i].d);
+//         const float32x4_t vm = vdupq_n_f32(x[i].m);
-        const uint8_t * restrict pp = x[i].qs;
+//         const uint8_t * restrict pp = x[i].qs;
-        for (int l = 0; l < QK; l += 16) {
-            // Load 16x4-bit integers into 8x8-bit integers
-            const uint8x8_t v8 = vld1_u8(pp + l/2);
+//         for (int l = 0; l < QK; l += 16) {
+//             // Load 16x4-bit integers into 8x8-bit integers
+//             const uint8x8_t v8 = vld1_u8(pp + l/2);
-            // Expand 4-bit qs to 8-bit bytes
-            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
-            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+//             // Expand 4-bit qs to 8-bit bytes
+//             const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+//             const uint8x8_t v1 = vshr_n_u8(v8, 4);
-            // Interleave and combine
-            const uint8x8_t vx_0 = vzip1_u8(v0, v1);
-            const uint8x8_t vx_1 = vzip2_u8(v0, v1);
+//             // Interleave and combine
+//             const uint8x8_t vx_0 = vzip1_u8(v0, v1);
+//             const uint8x8_t vx_1 = vzip2_u8(v0, v1);
-            const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
+//             const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
-            // convert to 2x uint16x8_t
-            const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq));
-            const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq));
+//             // convert to 2x uint16x8_t
+//             const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq));
+//             const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq));
-            // convert to 4x float32x4_t
-            const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
-            const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
-            const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
-            const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
+//             // convert to 4x float32x4_t
+//             const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
+//             const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
+//             const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
+//             const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
-            // multiply by d and add m
-            const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
-            const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
-            const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
-            const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
+//             // multiply by d and add m
+//             const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
+//             const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
+//             const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
+//             const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
-            // Store
-            vst1q_f32(y + i*QK + l + 0, r0);
-            vst1q_f32(y + i*QK + l + 4, r1);
-            vst1q_f32(y + i*QK + l + 8, r2);
-            vst1q_f32(y + i*QK + l + 12, r3);
-        }
-    }
+//             // Store
+//             vst1q_f32(y + i*QK + l + 0, r0);
+//             vst1q_f32(y + i*QK + l + 4, r1);
+//             vst1q_f32(y + i*QK + l + 8, r2);
+//             vst1q_f32(y + i*QK + l + 12, r3);
+//         }
+//     }
 #else
     for (int i = 0; i < nb; i++) {
         const float d = x[i].d;
@@ -1276,56 +1276,56 @@ static void dequantize_row_q4_1_o(const void * restrict vx, float * restrict y,
         // Restore the outlier
         y[i * QK + x[i].outlier_index] = ggml_rwkv_half_to_float_reference(x[i].outlier_value);
     }
-#elif defined(__ARM_NEON)
-    for (int i = 0; i < nb; i++) {
-        const float x_d = ggml_rwkv_half_to_float_reference(x[i].d);
-        const float x_m = ggml_rwkv_half_to_float_reference(x[i].m);
+// #elif defined(__ARM_NEON)
+//     for (int i = 0; i < nb; i++) {
+//         const float x_d = ggml_rwkv_half_to_float_reference(x[i].d);
+//         const float x_m = ggml_rwkv_half_to_float_reference(x[i].m);
-        const float32x4_t vd = vdupq_n_f32(x_d);
-        const float32x4_t vm = vdupq_n_f32(x_m);
+//         const float32x4_t vd = vdupq_n_f32(x_d);
+//         const float32x4_t vm = vdupq_n_f32(x_m);
-        const uint8_t * restrict pp = x[i].qs;
+//         const uint8_t * restrict pp = x[i].qs;
-        for (int l = 0; l < QK; l += 16) {
-            // Load 16x4-bit integers into 8x8-bit integers
-            const uint8x8_t v8 = vld1_u8(pp + l/2);
+//         for (int l = 0; l < QK; l += 16) {
+//             // Load 16x4-bit integers into 8x8-bit integers
+//             const uint8x8_t v8 = vld1_u8(pp + l/2);
-            // Expand 4-bit qs to 8-bit bytes
-            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
-            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+//             // Expand 4-bit qs to 8-bit bytes
+//             const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+//             const uint8x8_t v1 = vshr_n_u8(v8, 4);
-            // Interleave and combine
-            const uint8x8_t vx_0 = vzip1_u8(v0, v1);
-            const uint8x8_t vx_1 = vzip2_u8(v0, v1);
+//             // Interleave and combine
+//             const uint8x8_t vx_0 = vzip1_u8(v0, v1);
+//             const uint8x8_t vx_1 = vzip2_u8(v0, v1);
-            const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
+//             const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
-            // convert to 2x uint16x8_t
-            const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq));
-            const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq));
+//             // convert to 2x uint16x8_t
+//             const uint16x8_t vi_0 = vmovl_s8(vget_low_u8 (vq));
+//             const uint16x8_t vi_1 = vmovl_s8(vget_high_u8(vq));
-            // convert to 4x float32x4_t
-            const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
-            const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
-            const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
-            const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
+//             // convert to 4x float32x4_t
+//             const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
+//             const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
+//             const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
+//             const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
-            // multiply by d and add m
-            const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
-            const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
-            const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
-            const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
+//             // multiply by d and add m
+//             const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
+//             const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
+//             const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
+//             const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
-            // Store
-            vst1q_f32(y + i*QK + l + 0, r0);
-            vst1q_f32(y + i*QK + l + 4, r1);
-            vst1q_f32(y + i*QK + l + 8, r2);
-            vst1q_f32(y + i*QK + l + 12, r3);
-        }
+//             // Store
+//             vst1q_f32(y + i*QK + l + 0, r0);
+//             vst1q_f32(y + i*QK + l + 4, r1);
+//             vst1q_f32(y + i*QK + l + 8, r2);
+//             vst1q_f32(y + i*QK + l + 12, r3);
+//         }
-        // Restore the outlier
-        y[i * QK + x[i].outlier_index] = ggml_rwkv_half_to_float_reference(x[i].outlier_value);
-    }
+//         // Restore the outlier
+//         y[i * QK + x[i].outlier_index] = ggml_rwkv_half_to_float_reference(x[i].outlier_value);
+//     }
 #else
     for (int i = 0; i < nb; i++) {
         dequantize_row_q4_1_o_reference_single_block(x + i, y + i * QK);
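With the NEON branches commented out, non-AVX2 targets fall through to the scalar #else path: unpack two 4-bit values per byte and compute y = q * d + m. The following is a minimal sketch of that scalar Q4_1 dequantization, assuming the usual ggml-style Q4_1 block layout with QK = 32; the names q4_1_block and dequantize_block_q4_1_scalar are illustrative only and are not the identifiers used in ggml_rwkv.c.

    // Hypothetical, self-contained sketch of the scalar Q4_1 fallback.
    #include <stdint.h>

    #define QK 32                      // values per quantized block (assumed)

    typedef struct {
        float   d;                     // scale
        float   m;                     // minimum (offset)
        uint8_t qs[QK / 2];            // QK 4-bit quants, two per byte
    } q4_1_block;                      // illustrative name, not the real struct

    // Dequantize one block: y[l] = q * d + m, low nibble first, then high nibble.
    static void dequantize_block_q4_1_scalar(const q4_1_block * x, float * y) {
        for (int l = 0; l < QK; l += 2) {
            const uint8_t b = x->qs[l / 2];
            y[l + 0] = (float)(b & 0x0f) * x->d + x->m;   // low nibble
            y[l + 1] = (float)(b >>   4) * x->d + x->m;   // high nibble
        }
    }

For the Q4_1_O variant in the second hunk, the reference path (dequantize_row_q4_1_o_reference_single_block) additionally converts d and m from fp16 via ggml_rwkv_half_to_float_reference and, after the block is dequantized, overwrites y[outlier_index] with the stored fp16 outlier_value, as the removed NEON branch also did.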