diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7b0786036..acdc8108f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4822,6 +4822,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #endif } +// TODO: don't use lookup table for signs static __device__ __forceinline__ float vec_dot_iq3_xs_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics @@ -4829,23 +4830,21 @@ static __device__ __forceinline__ float vec_dot_iq3_xs_q8_1( const block_iq3_xs * bq2 = (const block_iq3_xs *) vbq; const int ib32 = iqs; - const uint8_t * q3 = bq2->qs + 8*ib32; - const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32; + const uint8_t * qs = bq2->qs + 8*ib32; const int8_t * q8 = bq8_1[ib32].qs; - uint32_t aux32 = gas[0] | (gas[1] << 16); int sumi = 0; for (int l = 0; l < 4; ++l) { - const uint32_t * grid1 = iq3xs_grid + (q3[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); - const uint32_t * grid2 = iq3xs_grid + (q3[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127)); - const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]); + const uint32_t * grid1 = iq3xs_grid + (qs[2*l+0] | ((bq2->qh[ib32] << (8 - 2*l)) & 256)); + const uint32_t * grid2 = iq3xs_grid + (qs[2*l+1] | ((bq2->qh[ib32] << (7 - 2*l)) & 256)); + uint32_t signs0 = __vcmpeq4(((bq2->signs[4*ib32+l] & 0xf) * 0x01010101) & 0x08040201, 0x08040201); + uint32_t signs1 = __vcmpeq4(((bq2->signs[4*ib32+l] >> 4) * 0x01010101) & 0x08040201, 0x08040201); + const int grid_l = __vsub4(grid1[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid2[0] ^ signs1, signs1); sumi = __dp4a(grid_l, *((int *)q8+0), sumi); sumi = __dp4a(grid_h, *((int *)q8+1), sumi); q8 += 8; - aux32 >>= 7; } - const float d = (float)bq2->d * (0.5f + aux32) * __low2float(bq8_1[ib32].ds) * 0.5f; + const float d = (float)bq2->d * (0.5f + ((bq2->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * __low2float(bq8_1[ib32].ds) * 0.5f; return d * sumi; #else assert(false);