ggml : RISC-V vector gemv for q4_0_8x8
This commit is contained in:
parent
66c2c93082
commit
9bfecf4294
1 changed files with 75 additions and 0 deletions
|
@ -991,6 +991,81 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
|
#elif defined(__riscv_v_intrinsic)
|
||||||
|
if (__riscv_vlenb() >= QK4_0) {
|
||||||
|
const size_t vl = QK4_0;
|
||||||
|
const vuint8m1_t lhs_idx_m1 = __riscv_vand_vx_u8m1(__riscv_vid_v_u8m1(vl), 7, vl);
|
||||||
|
const vuint8m2_t lhs_idx_m2 = __riscv_vcreate_v_u8m1_u8m2(lhs_idx_m1, lhs_idx_m1);
|
||||||
|
const vuint8m2_t lhs_idx_m2_hi = __riscv_vadd_vx_u8m2(lhs_idx_m2, 8, vl);
|
||||||
|
const vuint8m4_t lhs_idx_m4 = __riscv_vcreate_v_u8m2_u8m4(lhs_idx_m2, lhs_idx_m2_hi);
|
||||||
|
const vbool2_t mask0 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000000000FFull, vl / 8)));
|
||||||
|
const vbool2_t mask1 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000000000FF00ull, vl / 8)));
|
||||||
|
const vbool2_t mask2 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000000000FF0000ull, vl / 8)));
|
||||||
|
const vbool2_t mask3 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000FF000000ull, vl / 8)));
|
||||||
|
const vbool2_t mask4 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000FF00000000ull, vl / 8)));
|
||||||
|
const vbool2_t mask5 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000FF0000000000ull, vl / 8)));
|
||||||
|
const vbool2_t mask6 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00FF000000000000ull, vl / 8)));
|
||||||
|
const vbool2_t mask7 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0xFF00000000000000ull, vl / 8)));
|
||||||
|
|
||||||
|
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
|
||||||
|
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||||
|
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
|
||||||
|
|
||||||
|
vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
|
||||||
|
for (int l = 0; l < nb; l++) {
|
||||||
|
const vint8m1_t lhs_raw_vec = __riscv_vle8_v_i8m1(a_ptr[l].qs, vl);
|
||||||
|
const vint8m4_t lhs_raw_vec_lo = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, lhs_raw_vec);
|
||||||
|
const vint8m4_t lhs_raw_vec_hi = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, __riscv_vslidedown_vx_i8m1(lhs_raw_vec, 16, vl));
|
||||||
|
const vint8m4_t lhs_vec_lo = __riscv_vrgather_vv_i8m4(lhs_raw_vec_lo, lhs_idx_m4, vl * 4);
|
||||||
|
const vint8m4_t lhs_vec_hi = __riscv_vrgather_vv_i8m4(lhs_raw_vec_hi, lhs_idx_m4, vl * 4);
|
||||||
|
|
||||||
|
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
|
||||||
|
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
|
||||||
|
const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
|
||||||
|
|
||||||
|
const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4);
|
||||||
|
const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4);
|
||||||
|
const vint16m8_t sumi = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
|
||||||
|
|
||||||
|
const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4);
|
||||||
|
const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m8_i32m1_m(mask7, sumi, iaccz, vl * 4);
|
||||||
|
const vint32m1_t iacc7s = __riscv_vslideup_vx_i32m1(iacc7, iacc7, 1, vl / 4);
|
||||||
|
const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask6, iacc7s, sumi, iaccz, vl * 4);
|
||||||
|
const vint32m1_t iacc6s = __riscv_vslideup_vx_i32m1(iacc6, iacc6, 1, vl / 4);
|
||||||
|
const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask5, iacc6s, sumi, iaccz, vl * 4);
|
||||||
|
const vint32m1_t iacc5s = __riscv_vslideup_vx_i32m1(iacc5, iacc5, 1, vl / 4);
|
||||||
|
const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask4, iacc5s, sumi, iaccz, vl * 4);
|
||||||
|
const vint32m1_t iacc4s = __riscv_vslideup_vx_i32m1(iacc4, iacc4, 1, vl / 4);
|
||||||
|
const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask3, iacc4s, sumi, iaccz, vl * 4);
|
||||||
|
const vint32m1_t iacc3s = __riscv_vslideup_vx_i32m1(iacc3, iacc3, 1, vl / 4);
|
||||||
|
const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask2, iacc3s, sumi, iaccz, vl * 4);
|
||||||
|
const vint32m1_t iacc2s = __riscv_vslideup_vx_i32m1(iacc2, iacc2, 1, vl / 4);
|
||||||
|
const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask1, iacc2s, sumi, iaccz, vl * 4);
|
||||||
|
const vint32m1_t iacc1s = __riscv_vslideup_vx_i32m1(iacc1, iacc1, 1, vl / 4);
|
||||||
|
const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask0, iacc1s, sumi, iaccz, vl * 4);
|
||||||
|
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0, vl / 4);
|
||||||
|
|
||||||
|
// vector version needs Zvfhmin extension
|
||||||
|
const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
|
||||||
|
const float b_scales[8] = {
|
||||||
|
GGML_FP16_TO_FP32(b_ptr[l].d[0]),
|
||||||
|
GGML_FP16_TO_FP32(b_ptr[l].d[1]),
|
||||||
|
GGML_FP16_TO_FP32(b_ptr[l].d[2]),
|
||||||
|
GGML_FP16_TO_FP32(b_ptr[l].d[3]),
|
||||||
|
GGML_FP16_TO_FP32(b_ptr[l].d[4]),
|
||||||
|
GGML_FP16_TO_FP32(b_ptr[l].d[5]),
|
||||||
|
GGML_FP16_TO_FP32(b_ptr[l].d[6]),
|
||||||
|
GGML_FP16_TO_FP32(b_ptr[l].d[7])
|
||||||
|
};
|
||||||
|
const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4);
|
||||||
|
const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4);
|
||||||
|
const vfloat32m1_t tmp2 = __riscv_vfmul_vv_f32m1(tmp1, b_scales_vec, vl / 4);
|
||||||
|
sumf = __riscv_vfadd_vv_f32m1(sumf, tmp2, vl / 4);
|
||||||
|
}
|
||||||
|
__riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)
|
||||||
{
|
{
|
||||||
float sumf[8];
|
float sumf[8];
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue