ggml : Added initial implementation of rvv gemm
This commit is contained in:
parent
3f7fdf24b0
commit
238cd6674e
1 changed files with 160 additions and 47 deletions
|
@ -994,9 +994,6 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|||
#elif defined(__riscv_v_intrinsic)
|
||||
if (__riscv_vlenb() >= QK4_0) {
|
||||
const size_t vl = QK4_0;
|
||||
const uint8_t mask_bytes[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF};
|
||||
const vbool4_t mask7 = __riscv_vlm_v_b4(mask_bytes, vl * 2);
|
||||
const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4);
|
||||
|
||||
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
|
@ -1004,12 +1001,12 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|||
|
||||
vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
for (int l = 0; l < nb; l++) {
|
||||
const vint8m1_t lhs_0_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[0], vl / 8));
|
||||
const vint8m1_t lhs_1_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[8], vl / 8));
|
||||
const vint8m1_t lhs_2_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[16], vl / 8));
|
||||
const vint8m1_t lhs_3_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[24], vl / 8));
|
||||
const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m1_i8m4(lhs_0_4, lhs_0_4, lhs_1_4, lhs_1_4);
|
||||
const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m1_i8m4(lhs_2_4, lhs_2_4, lhs_3_4, lhs_3_4);
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[0], 0, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[8], 0, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[16], 0, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[24], 0, vl / 4));
|
||||
const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8);
|
||||
const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8);
|
||||
|
||||
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
|
||||
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
|
||||
|
@ -1022,29 +1019,19 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|||
const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1);
|
||||
const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2);
|
||||
|
||||
const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
|
||||
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
|
||||
const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
|
||||
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
|
||||
const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
|
||||
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
|
||||
const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
|
||||
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
|
||||
const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
|
||||
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
|
||||
const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
|
||||
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
|
||||
const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
|
||||
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
|
||||
const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
|
||||
const vint32m1_t iacc6s = __riscv_vslide1up_vx_i32m1(iacc7, __riscv_vmv_x_s_i32m1_i32(iacc6), vl / 4);
|
||||
const vint32m1_t iacc5s = __riscv_vslide1up_vx_i32m1(iacc6s, __riscv_vmv_x_s_i32m1_i32(iacc5), vl / 4);
|
||||
const vint32m1_t iacc4s = __riscv_vslide1up_vx_i32m1(iacc5s, __riscv_vmv_x_s_i32m1_i32(iacc4), vl / 4);
|
||||
const vint32m1_t iacc3s = __riscv_vslide1up_vx_i32m1(iacc4s, __riscv_vmv_x_s_i32m1_i32(iacc3), vl / 4);
|
||||
const vint32m1_t iacc2s = __riscv_vslide1up_vx_i32m1(iacc3s, __riscv_vmv_x_s_i32m1_i32(iacc2), vl / 4);
|
||||
const vint32m1_t iacc1s = __riscv_vslide1up_vx_i32m1(iacc2s, __riscv_vmv_x_s_i32m1_i32(iacc1), vl / 4);
|
||||
const vint32m1_t iacc0s = __riscv_vslide1up_vx_i32m1(iacc1s, __riscv_vmv_x_s_i32m1_i32(iacc0), vl / 4);
|
||||
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0s, vl / 4);
|
||||
const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi));
|
||||
const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl);
|
||||
const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2);
|
||||
const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2);
|
||||
const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2);
|
||||
const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2);
|
||||
const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4);
|
||||
const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4));
|
||||
const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4));
|
||||
const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4);
|
||||
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4);
|
||||
|
||||
// vector version needs Zvfhmin extension
|
||||
const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
|
||||
|
@ -3252,16 +3239,11 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|||
#elif defined(__riscv_v_intrinsic)
|
||||
if (__riscv_vlenb() >= QK4_0) {
|
||||
const size_t vl = QK4_0;
|
||||
const uint8_t mask_bytes[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF};
|
||||
const vbool4_t mask7 = __riscv_vlm_v_b4(mask_bytes, vl * 2);
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
|
||||
// for (int m = 0; m < 4; m++) {
|
||||
// for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
|
||||
// }
|
||||
vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||
|
@ -3271,19 +3253,150 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
|||
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
|
||||
const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
|
||||
|
||||
// vector version needs Zvfhmin extension
|
||||
const float a_scales[4] = {
|
||||
GGML_FP16_TO_FP32(a_ptr[l].d[0]),
|
||||
GGML_FP16_TO_FP32(a_ptr[l].d[1]),
|
||||
GGML_FP16_TO_FP32(a_ptr[l].d[2]),
|
||||
GGML_FP16_TO_FP32(a_ptr[l].d[3])
|
||||
};
|
||||
const float b_scales[8] = {
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[0]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[1]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[2]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[3]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[4]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[5]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[6]),
|
||||
GGML_FP16_TO_FP32(b_ptr[l].d[7])
|
||||
};
|
||||
const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
|
||||
|
||||
vint16m4_t sumi_l0;
|
||||
{
|
||||
const vint8m1_t lhs_0_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[0], vl / 8));
|
||||
const vint8m1_t lhs_1_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[32], vl / 8));
|
||||
const vint8m1_t lhs_2_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[64], vl / 8));
|
||||
const vint8m1_t lhs_3_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[96], vl / 8));
|
||||
const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m1_i8m4(lhs_0_4, lhs_0_4, lhs_1_4, lhs_1_4);
|
||||
const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m1_i8m4(lhs_2_4, lhs_2_4, lhs_3_4, lhs_3_4);
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[0], 0, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[32], 0, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[64], 0, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[96], 0, vl / 4));
|
||||
const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8);
|
||||
const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8);
|
||||
const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4);
|
||||
const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4);
|
||||
const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
|
||||
const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0);
|
||||
const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1);
|
||||
const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2);
|
||||
sumi_l0 = sumi;
|
||||
}
|
||||
__asm__ __volatile__("" ::: "memory");
|
||||
|
||||
vint16m4_t sumi_l1;
|
||||
{
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[8], 0, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[40], 0, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[72], 0, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[104], 0, vl / 4));
|
||||
const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8);
|
||||
const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8);
|
||||
const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4);
|
||||
const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4);
|
||||
const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
|
||||
const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0);
|
||||
const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1);
|
||||
const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2);
|
||||
sumi_l1 = sumi;
|
||||
}
|
||||
__asm__ __volatile__("" ::: "memory");
|
||||
|
||||
{
|
||||
const vint16m8_t sumi = __riscv_vcreate_v_i16m4_i16m8(sumi_l0, sumi_l1);
|
||||
const vuint32m8_t sumi_i32 = __riscv_vreinterpret_v_i32m8_u32m8(__riscv_vreinterpret_v_i16m8_i32m8(sumi));
|
||||
const vuint16m4_t sumi_h2_0 = __riscv_vnsrl_wx_u16m4(sumi_i32, 0, vl * 2);
|
||||
const vuint16m4_t sumi_h2_1 = __riscv_vnsrl_wx_u16m4(sumi_i32, 16, vl * 2);
|
||||
const vuint16m4_t sumi_h2 = __riscv_vadd_vv_u16m4(sumi_h2_0, sumi_h2_1, vl * 2);
|
||||
const vuint32m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m4_u32m4(sumi_h2);
|
||||
const vuint16m2_t sumi_h4_0 = __riscv_vnsrl_wx_u16m2(sumi_h2_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h4_1 = __riscv_vnsrl_wx_u16m2(sumi_h2_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h4 = __riscv_vadd_vv_u16m2(sumi_h4_0, sumi_h4_1, vl);
|
||||
const vuint32m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h4);
|
||||
const vint16m1_t sumi_h8_0 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(sumi_h4_i32, 0, vl / 2));
|
||||
const vint16m1_t sumi_h8_1 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(sumi_h4_i32, 16, vl / 2));
|
||||
const vint32m2_t sumi_h8 = __riscv_vwadd_vv_i32m2(sumi_h8_0, sumi_h8_1, vl / 2);
|
||||
const vfloat32m2_t facc = __riscv_vfcvt_f_x_v_f32m2(sumi_h8, vl / 2);
|
||||
|
||||
const vfloat32m1_t facc0 = __riscv_vget_v_f32m2_f32m1(facc, 0);
|
||||
const vfloat32m1_t tmp01 = __riscv_vfmul_vf_f32m1(facc0, a_scales[0], vl / 4);
|
||||
const vfloat32m1_t tmp02 = __riscv_vfmul_vv_f32m1(tmp01, b_scales_vec, vl / 4);
|
||||
sumf0 = __riscv_vfadd_vv_f32m1(sumf0, tmp02, vl / 4);
|
||||
const vfloat32m1_t facc1 = __riscv_vget_v_f32m2_f32m1(facc, 1);
|
||||
const vfloat32m1_t tmp11 = __riscv_vfmul_vf_f32m1(facc1, a_scales[1], vl / 4);
|
||||
const vfloat32m1_t tmp12 = __riscv_vfmul_vv_f32m1(tmp11, b_scales_vec, vl / 4);
|
||||
sumf1 = __riscv_vfadd_vv_f32m1(sumf1, tmp12, vl / 4);
|
||||
}
|
||||
__asm__ __volatile__("" ::: "memory");
|
||||
|
||||
vint16m4_t sumi_l2;
|
||||
{
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[16], 0, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[48], 0, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[80], 0, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[112], 0, vl / 4));
|
||||
const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8);
|
||||
const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8);
|
||||
const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4);
|
||||
const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4);
|
||||
const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
|
||||
const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0);
|
||||
const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1);
|
||||
const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2);
|
||||
sumi_l2 = sumi;
|
||||
}
|
||||
__asm__ __volatile__("" ::: "memory");
|
||||
|
||||
vint16m4_t sumi_l3;
|
||||
{
|
||||
const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[24], 0, vl / 4));
|
||||
const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[56], 0, vl / 4));
|
||||
const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[88], 0, vl / 4));
|
||||
const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vlse64_v_i64m2((const int64_t *)&a_ptr[l].qs[120], 0, vl / 4));
|
||||
const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m2_i8m4(lhs_0_8, lhs_1_8);
|
||||
const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m2_i8m4(lhs_2_8, lhs_3_8);
|
||||
const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4);
|
||||
const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4);
|
||||
const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
|
||||
const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0);
|
||||
const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1);
|
||||
const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2);
|
||||
sumi_l3 = sumi;
|
||||
}
|
||||
__asm__ __volatile__("" ::: "memory");
|
||||
|
||||
{
|
||||
const vint16m8_t sumi = __riscv_vcreate_v_i16m4_i16m8(sumi_l2, sumi_l3);
|
||||
const vuint32m8_t sumi_i32 = __riscv_vreinterpret_v_i32m8_u32m8(__riscv_vreinterpret_v_i16m8_i32m8(sumi));
|
||||
const vuint16m4_t sumi_h2_0 = __riscv_vnsrl_wx_u16m4(sumi_i32, 0, vl * 2);
|
||||
const vuint16m4_t sumi_h2_1 = __riscv_vnsrl_wx_u16m4(sumi_i32, 16, vl * 2);
|
||||
const vuint16m4_t sumi_h2 = __riscv_vadd_vv_u16m4(sumi_h2_0, sumi_h2_1, vl * 2);
|
||||
const vuint32m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m4_u32m4(sumi_h2);
|
||||
const vuint16m2_t sumi_h4_0 = __riscv_vnsrl_wx_u16m2(sumi_h2_i32, 0, vl);
|
||||
const vuint16m2_t sumi_h4_1 = __riscv_vnsrl_wx_u16m2(sumi_h2_i32, 16, vl);
|
||||
const vuint16m2_t sumi_h4 = __riscv_vadd_vv_u16m2(sumi_h4_0, sumi_h4_1, vl);
|
||||
const vuint32m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h4);
|
||||
const vint16m1_t sumi_h8_0 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(sumi_h4_i32, 0, vl / 2));
|
||||
const vint16m1_t sumi_h8_1 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(sumi_h4_i32, 16, vl / 2));
|
||||
const vint32m2_t sumi_h8 = __riscv_vwadd_vv_i32m2(sumi_h8_0, sumi_h8_1, vl / 2);
|
||||
const vfloat32m2_t facc = __riscv_vfcvt_f_x_v_f32m2(sumi_h8, vl / 2);
|
||||
|
||||
const vfloat32m1_t facc0 = __riscv_vget_v_f32m2_f32m1(facc, 0);
|
||||
const vfloat32m1_t tmp01 = __riscv_vfmul_vf_f32m1(facc0, a_scales[2], vl / 4);
|
||||
const vfloat32m1_t tmp02 = __riscv_vfmul_vv_f32m1(tmp01, b_scales_vec, vl / 4);
|
||||
sumf2 = __riscv_vfadd_vv_f32m1(sumf2, tmp02, vl / 4);
|
||||
const vfloat32m1_t facc1 = __riscv_vget_v_f32m2_f32m1(facc, 1);
|
||||
const vfloat32m1_t tmp11 = __riscv_vfmul_vf_f32m1(facc1, a_scales[3], vl / 4);
|
||||
const vfloat32m1_t tmp12 = __riscv_vfmul_vv_f32m1(tmp11, b_scales_vec, vl / 4);
|
||||
sumf3 = __riscv_vfadd_vv_f32m1(sumf3, tmp12, vl / 4);
|
||||
}
|
||||
}
|
||||
// for (int m = 0; m < 4; m++) {
|
||||
// for (int j = 0; j < ncols_interleaved; j++)
|
||||
// s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
||||
// }
|
||||
__riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
|
||||
__riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
|
||||
__riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue