q2_K sc_high

This commit is contained in:
JohannesGaessler 2023-07-28 19:27:44 +02:00
parent 58daf95aa3
commit abed446346

View file

@ -1733,7 +1733,11 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl(
const int vi = (v >> (2*i)) & 0x03030303; const int vi = (v >> (2*i)) & 0x03030303;
sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
sumf_m += d8[i] * (__dp4a(0x01010101, u[i], 0) * (sc >> 4)); // multiply constant q2_K part with sum of q8_1 values
int sc_high = sc >> 4;
sc_high |= sc_high << 8;
sc_high |= sc_high << 16;
sumf_m += d8[i] * __dp4a(sc_high, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
} }
const float2 dmf = __half22float2(dm); const float2 dmf = __half22float2(dm);
@ -1795,6 +1799,10 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
__builtin_assume(i < 2*WARP_SIZE);
__builtin_assume(j < WARP_SIZE);
__builtin_assume(k < WARP_SIZE);
const int kbx = k / QI2_K; const int kbx = k / QI2_K;
const int kqsx = k % QI2_K; const int kqsx = k % QI2_K;