From 68cfcd4711da27f2a83b0efe2e36f94ccf1e09d4 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sun, 28 Jan 2024 12:24:47 +0200
Subject: [PATCH] Faster iq3_xxs and iq2_xs dot products on CUDA
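
Replace the scalar sign-application loops in vec_dot_iq2_xs_q8_1 and
vec_dot_iq3_xxs_q8_1 with packed integer math. A new 128-entry table,
ksigns64, stores each sign pattern expanded to eight 0x00/0xff mask bytes,
so the signs can be applied to four packed grid values at once with
__vsub4(grid ^ signs, signs) and the products with the q8_1 quants
accumulated with __dp4a. Both dot products now require a compute capability
of at least MIN_CC_DP4A (the lowest that provides these integer intrinsics)
and fall back to assert(false) otherwise.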

---
 ggml-cuda.cu | 83 +++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 66 insertions(+), 17 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index f7cd55b13..f8dd4ce60 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1667,6 +1667,43 @@ static const __device__ uint8_t ksigns_iq2xs[128] = {
     240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
 };
 
+//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+static const __device__ uint64_t ksigns64[128] = {
+    0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
+    0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
+    0xff000000ff000000, 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff,
+    0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, 0x00000000ffffffff,
+    0xff0000ff00000000, 0x000000ff000000ff, 0x000000ff0000ff00, 0xff0000ff0000ffff,
+    0x000000ff00ff0000, 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff,
+    0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, 0x000000ffff00ffff,
+    0xff0000ffffff0000, 0x000000ffffff00ff, 0x000000ffffffff00, 0xff0000ffffffffff,
+    0xff00ff0000000000, 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff,
+    0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, 0x0000ff0000ffffff,
+    0x0000ff00ff000000, 0xff00ff00ff0000ff, 0xff00ff00ff00ff00, 0x0000ff00ff00ffff,
+    0xff00ff00ffff0000, 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff,
+    0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, 0x0000ffff0000ffff,
+    0xff00ffff00ff0000, 0x0000ffff00ff00ff, 0x0000ffff00ffff00, 0xff00ffff00ffffff,
+    0xff00ffffff000000, 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff,
+    0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, 0x0000ffffffffffff,
+    0xffff000000000000, 0x00ff0000000000ff, 0x00ff00000000ff00, 0xffff00000000ffff,
+    0x00ff000000ff0000, 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff,
+    0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, 0x00ff0000ff00ffff,
+    0xffff0000ffff0000, 0x00ff0000ffff00ff, 0x00ff0000ffffff00, 0xffff0000ffffffff,
+    0x00ff00ff00000000, 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff,
+    0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, 0xffff00ff00ffffff,
+    0xffff00ffff000000, 0x00ff00ffff0000ff, 0x00ff00ffff00ff00, 0xffff00ffff00ffff,
+    0x00ff00ffffff0000, 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff,
+    0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, 0x00ffff000000ffff,
+    0xffffff0000ff0000, 0x00ffff0000ff00ff, 0x00ffff0000ffff00, 0xffffff0000ffffff,
+    0xffffff00ff000000, 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff,
+    0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, 0x00ffff00ffffffff,
+    0xffffffff00000000, 0x00ffffff000000ff, 0x00ffffff0000ff00, 0xffffffff0000ffff,
+    0x00ffffff00ff0000, 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff,
+    0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
+    0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
+};
+//#endif
+
 static const __device__ uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
 
 inline bool ggml_cuda_supports_mmq(enum ggml_type type) {
@@ -4384,6 +4421,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
 
 static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
 #if QK_K == 256
 
     const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
@@ -4394,20 +4432,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     const uint8_t ls2 = bq2->scales[ib32] >> 4;
     int sumi1 = 0;
     for (int l = 0; l < 2; ++l) {
-        const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-        const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
-        for (int j = 0; j < 8; ++j) {
-            sumi1 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-        }
+        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+        const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
+        const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
+        sumi1 = __dp4a(grid_l, *((const int *)q8 + 0), sumi1);
+        sumi1 = __dp4a(grid_h, *((const int *)q8 + 1), sumi1);
         q8 += 8;
     }
     int sumi2 = 0;
     for (int l = 2; l < 4; ++l) {
-        const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511));
-        const uint8_t signs = ksigns_iq2xs[q2[l] >> 9];
-        for (int j = 0; j < 8; ++j) {
-            sumi2 += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1);
-        }
+        const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511));
+        const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+        const int grid_l = __vsub4(grid[0] ^ signs[0], signs[0]);
+        const int grid_h = __vsub4(grid[1] ^ signs[1], signs[1]);
+        sumi2 = __dp4a(grid_l, *((const int *)q8 + 0), sumi2);
+        sumi2 = __dp4a(grid_h, *((const int *)q8 + 1), sumi2);
         q8 += 8;
     }
     const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
@@ -4416,10 +4456,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     assert(false);
     return 0.f;
 #endif
+#else
+    assert(false);
+    return 0.f;
+#endif
 }
 
 static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
 #if QK_K == 256
 
     const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
@@ -4430,13 +4475,13 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
     uint32_t aux32 = gas[0] | (gas[1] << 16);
     int sumi = 0;
     for (int l = 0; l < 4; ++l) {
-        const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]);
-        const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]);
-        const uint8_t signs = ksigns_iq2xs[aux32 & 127];
-        for (int j = 0; j < 4; ++j) {
-            sumi += q8[j+0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1 : 1);
-            sumi += q8[j+4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1 : 1);
-        }
+        const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0];
+        const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1];
+        const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127));
+        const int grid_l = __vsub4(grid1[0] ^ signs[0], signs[0]);
+        const int grid_h = __vsub4(grid2[0] ^ signs[1], signs[1]);
+        sumi = __dp4a(grid_l, *((int *)q8+0), sumi);
+        sumi = __dp4a(grid_h, *((int *)q8+1), sumi);
         q8 += 8;
         aux32 >>= 7;
     }
@@ -4446,6 +4491,10 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
     assert(false);
     return 0.f;
 #endif
+#else
+    assert(false);
+    return 0.f;
+#endif
 }
 
 template
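
Note (not part of the patch): the new code relies on the identity that for a
mask byte s in {0x00, 0xff}, the byte-wise value (x ^ s) - s equals +x when
s == 0x00 and -x when s == 0xff, so __vsub4(grid ^ signs, signs) conditionally
negates four packed int8 grid values at once. The sketch below is a host-side
C++ emulation under assumed helper names (vsub4 and make_ksigns64 are not from
the patch); it demonstrates the identity and checks one ksigns64 entry against
the ksigns_iq2xs bit convention.

    // Host-side sketch: emulates CUDA's __vsub4 to show the sign trick.
    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Per-byte a - b with no borrow between bytes, like CUDA's __vsub4.
    static uint32_t vsub4(uint32_t a, uint32_t b) {
        uint32_t r = 0;
        for (int i = 0; i < 4; ++i) {
            const uint8_t ba = uint8_t(a >> 8*i), bb = uint8_t(b >> 8*i);
            r |= uint32_t(uint8_t(ba - bb)) << 8*i;
        }
        return r;
    }

    // Expands an 8-bit sign pattern (ksigns_iq2xs convention: bit j set
    // means element j is negative) into one 64-bit ksigns64 entry with
    // byte j equal to 0xff where bit j is set.
    static uint64_t make_ksigns64(uint8_t ksigns) {
        uint64_t m = 0;
        for (int j = 0; j < 8; ++j)
            if (ksigns & (1u << j)) m |= 0xffull << 8*j;
        return m;
    }

    int main() {
        // ksigns_iq2xs[1] == 129 (bits 0 and 7 set), and the table above has
        // ksigns64[1] == 0xff000000000000ff: bytes 0 and 7 are 0xff.
        assert(make_ksigns64(129) == 0xff000000000000ffull);

        const uint32_t grid  = 0x04010308;  // four packed grid bytes
        const uint32_t signs = 0xff00ff00;  // 0xff marks the bytes to negate
        const uint32_t res   = vsub4(grid ^ signs, signs);
        for (int i = 0; i < 4; ++i) {
            // Bytes 1 and 3 come out negated; bytes 0 and 2 are unchanged.
            printf("byte %d: %4d -> %4d\n", i,
                   int8_t(grid >> 8*i), int8_t(res >> 8*i));
        }
        return 0;
    }

Compiled as plain C++ this prints 8 -> 8, 3 -> -3, 1 -> 1, 4 -> -4, i.e. the
same signed values the old per-byte kmask_iq2xs loop multiplied into the sum.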