Remove zero vector parameter passing

Srihari-mcw 2024-09-22 20:21:10 -07:00
parent a829583c97
commit 7aee79bda5


@@ -139,7 +139,8 @@ static inline __m512i mul_sum_us8_pairs_int_512(const __m512i ax, const __m512i
}
// multiply int8_t, add results pairwise twice and return as int vector
static inline __m512i mul_sum_i8_pairs_int_512(const __m512i x, const __m512i y, const __m512i zero) {
static inline __m512i mul_sum_i8_pairs_int_512(const __m512i x, const __m512i y) {
const __m512i zero = _mm512_setzero_si512();
// Get absolute values of x vectors
const __m512i ax = _mm512_abs_epi8(x);
// Sign the values of the y vectors
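For reference, a minimal sketch of what the helper looks like after this change. Only the two-operand signature, the local zero vector, and the two comments above are taken from the hunk; the rest of the body is not shown here, so the masked negation and the maddubs/madd reduction below are an assumed non-VNNI implementation (AVX-512 has no vpsignb, which is why a zero vector is still needed inside the function), not the exact ggml code.

#include <immintrin.h>

// Sketch only: signature and local zero vector match this commit; the body below
// is an assumed non-VNNI implementation, not copied from ggml.
static inline __m512i mul_sum_i8_pairs_int_512_sketch(const __m512i x, const __m512i y) {
    const __m512i zero = _mm512_setzero_si512();
    // Get absolute values of x vectors
    const __m512i ax = _mm512_abs_epi8(x);
    // Sign the values of the y vectors: negate y wherever x is negative, so that
    // ax * sy == x * y lane-wise (masked subtract from zero stands in for vpsignb)
    const __mmask64 x_neg = _mm512_movepi8_mask(x);
    const __m512i sy = _mm512_mask_sub_epi8(y, x_neg, zero, y);
    // Multiply u8*s8 pairs into 16-bit lanes, then pairwise-add into 32-bit lanes
    const __m512i dot = _mm512_maddubs_epi16(ax, sy);
    return _mm512_madd_epi16(_mm512_set1_epi16(1), dot);
}

Because _mm512_setzero_si512() lowers to a register-zeroing idiom (vpxor), materializing the zero vector inside the helper costs effectively nothing, and callers no longer have to keep an extra __m512i live just to pass it in.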
@@ -2496,8 +2497,6 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
const __m512i m4bexpanded = _mm512_set1_epi8(0x0F);
// Lookup table to convert signed nibbles to signed bytes expanded to 512 bit length
__m512i signextendlutexpanded = _mm512_inserti32x8(_mm512_castsi256_si512(signextendlut), signextendlut, 1);
// Zero vector of length 512 bits
__m512i zero = _mm512_setzero_si512();
// Take group of four block_q8_0x4 structures at each pass of the loop and perform dot product operation
for (; y < anr / 4; y += 4) {
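With the helper creating its own zero vector, the GEMM kernel drops the declaration above and every dot-product call takes just the two operands, as the hunks below show. A schematic of one four-term accumulator after the change, reusing the hypothetical sketch helper from above (the array names are illustrative, not the kernel's lhs_mat_*/rhs_mat_* shuffle outputs):

// Schematic: no zero vector is declared in the kernel or threaded through the calls.
static inline __m512i acc_four_sketch(const __m512i lhs[4], const __m512i rhs[4]) {
    return _mm512_add_epi32(
               _mm512_add_epi32(
                   _mm512_add_epi32(mul_sum_i8_pairs_int_512_sketch(lhs[3], rhs[3]),
                                    mul_sum_i8_pairs_int_512_sketch(lhs[2], rhs[2])),
                   mul_sum_i8_pairs_int_512_sketch(lhs[1], rhs[1])),
               mul_sum_i8_pairs_int_512_sketch(lhs[0], rhs[0]));
}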
@@ -2650,21 +2649,21 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
// Resembles MMLAs into 2x2 matrices in ARM Version
__m512i iacc_mat_00_sp1 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1, zero), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
__m512i iacc_mat_01_sp1 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1, zero), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
__m512i iacc_mat_10_sp1 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1, zero), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
__m512i iacc_mat_11_sp1 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1, zero), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
__m512i iacc_mat_00_sp2 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2, zero), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
__m512i iacc_mat_01_sp2 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2, zero), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
__m512i iacc_mat_10_sp2 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2, zero), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
__m512i iacc_mat_11_sp2 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2, zero), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
__m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);
@@ -2841,21 +2840,21 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
// The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane
// Resembles MMLAs into 2x2 matrices in ARM Version
__m512i iacc_mat_00_sp1 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1, zero), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1));
__m512i iacc_mat_01_sp1 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1, zero), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1));
__m512i iacc_mat_10_sp1 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1, zero), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1));
__m512i iacc_mat_11_sp1 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1, zero), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1));
__m512i iacc_mat_00_sp2 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2, zero), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2));
__m512i iacc_mat_01_sp2 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2, zero), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int_512(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2));
__m512i iacc_mat_10_sp2 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2, zero), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2));
__m512i iacc_mat_11_sp2 =
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2, zero), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2, zero)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2, zero));
_mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int_512(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int_512(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int_512(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2));
// Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block
__m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2);