diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1a13d2b8e..b1b1e6512 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -317,34 +317,49 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { } -static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * y, float & result) { +static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { const block_q3_K * x = (const block_q3_K *) vx; - int n = iqs / 128; - int iqsn = iqs - 128*n; - int j = iqsn / 32; - int l = iqsn - 32*j; - int shift = 2*j; - int iss = 2*j + l/16; - int is = 8*n + iss; - int is_shift = 2*(is/4); - uint8_t m = 1 << (4*n + j); + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; - const float d = x[ib].d; + uint32_t aux[3]; + uint32_t utmp[4]; + + // if n is 0, we want to do the lower 128, else the upper 128, + // covering y[l+0], y[l+32], y[l+64], y[l+96] and + // y[l+16], y[l+48], y[l+80], y[l+112] + int n = iqs/128; // 0 or 1 + int r = iqs - 128*n; // 0...120 in steps of 8 + int l = r/8; // 0...15 in steps of 1 + + const float * y = yy + 128*n + l; const uint8_t * q = x[ib].qs + 32*n + l; const uint8_t * hm = x[ib].hmask + l; + const int8_t * s = (const int8_t *)utmp + 8*n; - int8_t us = n == 0 ? (x[ib].scales[iss] & 0xF) | (((x[ib].scales[is+8-2*is_shift] >> is_shift) & 3) << 4) - : (x[ib].scales[iss] >> 4 ) | (((x[ib].scales[is+8-2*is_shift] >> is_shift) & 3) << 4); - float scale = d * (us - 32); + memcpy(aux, x[ib].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + const float dall = x[ib].d; + + const uint8_t m = 1 << (4*n); + + float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4)) + + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4)) + + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4)) + + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4)) + + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4)) + + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4)) + + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4)) + + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4)); + + result = sum * dall; - float sum = 0; - for (int k = 0; k < 8; ++k) { - int8_t ql = (q[k] >> shift) & 3; - sum += y[iqs + k] * (ql - ((hm[k] & m) ? 0 : 4)); - } - result = sum * scale; } static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {