ggml : Added WIP rvv q4_0_8x8 gemm

This commit is contained in:
Xiongchuan Tan 2024-10-22 02:42:35 +08:00
parent 9bfecf4294
commit 3f7fdf24b0

View file

@ -994,18 +994,9 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
#elif defined(__riscv_v_intrinsic)
if (__riscv_vlenb() >= QK4_0) {
const size_t vl = QK4_0;
const vuint8m1_t lhs_idx_m1 = __riscv_vand_vx_u8m1(__riscv_vid_v_u8m1(vl), 7, vl);
const vuint8m2_t lhs_idx_m2 = __riscv_vcreate_v_u8m1_u8m2(lhs_idx_m1, lhs_idx_m1);
const vuint8m2_t lhs_idx_m2_hi = __riscv_vadd_vx_u8m2(lhs_idx_m2, 8, vl);
const vuint8m4_t lhs_idx_m4 = __riscv_vcreate_v_u8m2_u8m4(lhs_idx_m2, lhs_idx_m2_hi);
const vbool2_t mask0 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000000000FFull, vl / 8)));
const vbool2_t mask1 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000000000FF00ull, vl / 8)));
const vbool2_t mask2 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000000000FF0000ull, vl / 8)));
const vbool2_t mask3 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00000000FF000000ull, vl / 8)));
const vbool2_t mask4 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x000000FF00000000ull, vl / 8)));
const vbool2_t mask5 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x0000FF0000000000ull, vl / 8)));
const vbool2_t mask6 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0x00FF000000000000ull, vl / 8)));
const vbool2_t mask7 = __riscv_vreinterpret_v_u16m1_b2(__riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(0xFF00000000000000ull, vl / 8)));
const uint8_t mask_bytes[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF};
const vbool4_t mask7 = __riscv_vlm_v_b4(mask_bytes, vl * 2);
const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4);
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
@ -1013,11 +1004,12 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
for (int l = 0; l < nb; l++) {
const vint8m1_t lhs_raw_vec = __riscv_vle8_v_i8m1(a_ptr[l].qs, vl);
const vint8m4_t lhs_raw_vec_lo = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, lhs_raw_vec);
const vint8m4_t lhs_raw_vec_hi = __riscv_vset_v_i8m1_i8m4(__riscv_vundefined_i8m4(), 0, __riscv_vslidedown_vx_i8m1(lhs_raw_vec, 16, vl));
const vint8m4_t lhs_vec_lo = __riscv_vrgather_vv_i8m4(lhs_raw_vec_lo, lhs_idx_m4, vl * 4);
const vint8m4_t lhs_vec_hi = __riscv_vrgather_vv_i8m4(lhs_raw_vec_hi, lhs_idx_m4, vl * 4);
const vint8m1_t lhs_0_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[0], vl / 8));
const vint8m1_t lhs_1_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[8], vl / 8));
const vint8m1_t lhs_2_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[16], vl / 8));
const vint8m1_t lhs_3_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[24], vl / 8));
const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m1_i8m4(lhs_0_4, lhs_0_4, lhs_1_4, lhs_1_4);
const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m1_i8m4(lhs_2_4, lhs_2_4, lhs_3_4, lhs_3_4);
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
@ -1025,25 +1017,34 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
const vint16m8_t sumi_lo = __riscv_vwmul_vv_i16m8(rhs_vec_lo, lhs_vec_lo, vl * 4);
const vint16m8_t sumi_hi = __riscv_vwmul_vv_i16m8(rhs_vec_hi, lhs_vec_hi, vl * 4);
const vint16m8_t sumi = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
const vint16m8_t sumi2 = __riscv_vadd_vv_i16m8(sumi_lo, sumi_hi, vl * 4);
const vint16m4_t sumi2_lo = __riscv_vget_v_i16m8_i16m4(sumi2, 0);
const vint16m4_t sumi2_hi = __riscv_vget_v_i16m8_i16m4(sumi2, 1);
const vint16m4_t sumi = __riscv_vadd_vv_i16m4(sumi2_lo, sumi2_hi, vl * 2);
const vint32m1_t iaccz = __riscv_vmv_v_x_i32m1(0, vl / 4);
const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m8_i32m1_m(mask7, sumi, iaccz, vl * 4);
const vint32m1_t iacc7s = __riscv_vslideup_vx_i32m1(iacc7, iacc7, 1, vl / 4);
const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask6, iacc7s, sumi, iaccz, vl * 4);
const vint32m1_t iacc6s = __riscv_vslideup_vx_i32m1(iacc6, iacc6, 1, vl / 4);
const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask5, iacc6s, sumi, iaccz, vl * 4);
const vint32m1_t iacc5s = __riscv_vslideup_vx_i32m1(iacc5, iacc5, 1, vl / 4);
const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask4, iacc5s, sumi, iaccz, vl * 4);
const vint32m1_t iacc4s = __riscv_vslideup_vx_i32m1(iacc4, iacc4, 1, vl / 4);
const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask3, iacc4s, sumi, iaccz, vl * 4);
const vint32m1_t iacc3s = __riscv_vslideup_vx_i32m1(iacc3, iacc3, 1, vl / 4);
const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask2, iacc3s, sumi, iaccz, vl * 4);
const vint32m1_t iacc2s = __riscv_vslideup_vx_i32m1(iacc2, iacc2, 1, vl / 4);
const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask1, iacc2s, sumi, iaccz, vl * 4);
const vint32m1_t iacc1s = __riscv_vslideup_vx_i32m1(iacc1, iacc1, 1, vl / 4);
const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m8_i32m1_tum(mask0, iacc1s, sumi, iaccz, vl * 4);
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0, vl / 4);
const vint32m1_t iacc7 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
const vint32m1_t iacc6 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
const vint32m1_t iacc5 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
const vint32m1_t iacc4 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
const vint32m1_t iacc3 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
const vint32m1_t iacc2 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
const vint32m1_t iacc1 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
sumi = __riscv_vslideup_vx_i16m4(sumi, sumi, 8, vl * 2);
const vint32m1_t iacc0 = __riscv_vwredsum_vs_i16m4_i32m1_m(mask7, sumi, iaccz, vl * 2);
const vint32m1_t iacc6s = __riscv_vslide1up_vx_i32m1(iacc7, __riscv_vmv_x_s_i32m1_i32(iacc6), vl / 4);
const vint32m1_t iacc5s = __riscv_vslide1up_vx_i32m1(iacc6s, __riscv_vmv_x_s_i32m1_i32(iacc5), vl / 4);
const vint32m1_t iacc4s = __riscv_vslide1up_vx_i32m1(iacc5s, __riscv_vmv_x_s_i32m1_i32(iacc4), vl / 4);
const vint32m1_t iacc3s = __riscv_vslide1up_vx_i32m1(iacc4s, __riscv_vmv_x_s_i32m1_i32(iacc3), vl / 4);
const vint32m1_t iacc2s = __riscv_vslide1up_vx_i32m1(iacc3s, __riscv_vmv_x_s_i32m1_i32(iacc2), vl / 4);
const vint32m1_t iacc1s = __riscv_vslide1up_vx_i32m1(iacc2s, __riscv_vmv_x_s_i32m1_i32(iacc1), vl / 4);
const vint32m1_t iacc0s = __riscv_vslide1up_vx_i32m1(iacc1s, __riscv_vmv_x_s_i32m1_i32(iacc0), vl / 4);
const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(iacc0s, vl / 4);
// vector version needs Zvfhmin extension
const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d);
@ -3246,6 +3247,50 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
}
return;
}
#elif defined(__riscv_v_intrinsic)
if (__riscv_vlenb() >= QK4_0) {
const size_t vl = QK4_0;
const uint8_t mask_bytes[] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xFF};
const vbool4_t mask7 = __riscv_vlm_v_b4(mask_bytes, vl * 2);
for (int y = 0; y < nr / 4; y++) {
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
// for (int m = 0; m < 4; m++) {
// for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
// }
vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4);
for (int l = 0; l < nb; l++) {
const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4);
const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4);
const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4);
{
const vint8m1_t lhs_0_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[0], vl / 8));
const vint8m1_t lhs_1_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[32], vl / 8));
const vint8m1_t lhs_2_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[64], vl / 8));
const vint8m1_t lhs_3_4 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(*(int64_t *)&a_ptr[l].qs[96], vl / 8));
const vint8m4_t lhs_vec_lo = __riscv_vcreate_v_i8m1_i8m4(lhs_0_4, lhs_0_4, lhs_1_4, lhs_1_4);
const vint8m4_t lhs_vec_hi = __riscv_vcreate_v_i8m1_i8m4(lhs_2_4, lhs_2_4, lhs_3_4, lhs_3_4);
}
}
// for (int m = 0; m < 4; m++) {
// for (int j = 0; j < ncols_interleaved; j++)
// s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
// }
__riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4);
__riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4);
__riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4);
__riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4);
}
}
return;
}
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__)