Add additional comments

This commit is contained in:
Srihari-mcw 2024-05-23 08:49:41 -07:00
parent e26fd70dce
commit 983b03ab6a
2 changed files with 6 additions and 1 deletions

View file

@ -4368,6 +4368,7 @@ void ggml_vec_dot_q4_0_b16_q8_0_b16(int n, float * restrict s, size_t bs, const
__m128bh xd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, x_delta)));
__m128bh yd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, y_delta)));
// Computes product of delta values from four corresponding blocks
__m256 d = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, xd, yd));
d = _mm256_permute2f128_ps(d ,d, 0);
@ -5902,6 +5903,7 @@ void ggml_vec_dot_q8_0_b16_q8_0_b16(int n, float * restrict s, size_t bs, const
__m128bh xd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, x_delta)));
__m128bh yd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, y_delta)));
// Computes product of delta values from four corresponding blocks
__m256 d = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, xd, yd));
d = _mm256_permute2f128_ps(d ,d, 0);

View file

@ -981,6 +981,7 @@ class tinyBLAS_Q0_B16_AVX {
}
#if defined(__AVX512BF16__)
// Templated functions for gemm of dimesnions 4xN
template <int RN>
NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / 4;
@ -1005,6 +1006,7 @@ class tinyBLAS_Q0_B16_AVX {
__m256i avec3 = load(A + lda * (ii + 3) + l);
for (int64_t j = 0; j < RN; ++j) {
__m128bh db = m128bh(_mm_set1_epi16(B[ldb * (jj + j) + l].d));
// Computation of product of delta values for four blocks
__m256 dvec = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, da, db));
dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
@ -1056,7 +1058,8 @@ class tinyBLAS_Q0_B16_AVX {
__m256i bvec3 = load(B + ldb * (jj + 3) + l);
for (int64_t i = 0; i < RM; ++i) {
__m128bh da = m128bh(_mm_set1_epi16((A[lda * (ii + i) + l].d)));
__m256 dvec = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, da, db));
// Computation of product of delta values for four blocks
__m256 dvec = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, da, db));
dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),