From e48fd74b45c5d39f88b15302f87047364df0ecf5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 4 Jul 2024 21:19:09 +0300 Subject: [PATCH 1/2] ggml : remove k_quants_per_iteration macro ggml-ci --- Makefile | 8 - ggml/CMakeLists.txt | 2 - ggml/src/CMakeLists.txt | 2 - ggml/src/ggml-cuda/dmmv.cu | 120 ++++--------- ggml/src/ggml-sycl/dmmv.cpp | 163 ++++++------------ ggml/src/ggml-sycl/presets.hpp | 6 - ggml/src/ggml-vulkan.cpp | 6 - ggml/src/vulkan-shaders/mul_mat_vec_base.comp | 2 - ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp | 12 +- ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp | 12 +- ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp | 35 +--- ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp | 25 +-- 12 files changed, 110 insertions(+), 283 deletions(-) diff --git a/Makefile b/Makefile index 332496cfc..3f0d91b85 100644 --- a/Makefile +++ b/Makefile @@ -688,12 +688,6 @@ ifdef GGML_CUDA_DMMV_F16 MK_NVCCFLAGS += -DGGML_CUDA_F16 endif # GGML_CUDA_DMMV_F16 -ifdef GGML_CUDA_KQUANTS_ITER - MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER) -else - MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2 -endif - ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE) else @@ -810,7 +804,6 @@ ifdef GGML_HIPBLAS GGML_CUDA_DMMV_X ?= 32 GGML_CUDA_MMV_Y ?= 1 - GGML_CUDA_KQUANTS_ITER ?= 2 MK_CPPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUDA @@ -827,7 +820,6 @@ endif # GGML_HIP_UMA HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS)) HIPFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X) HIPFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y) - HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER) ifdef GGML_CUDA_FORCE_DMMV HIPFLAGS += -DGGML_CUDA_FORCE_DMMV diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index cc1685884..c55d889f3 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -120,8 +120,6 @@ option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels") set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels") option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF) -set (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING - "ggml: iters./thread per block for Q2_K/Q6_K") set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING "ggml: max. batch size for using peer access") option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index ff84b9bb5..843c2a68f 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -323,7 +323,6 @@ if (GGML_CUDA) add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) - add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) if (GGML_CUDA_USE_GRAPHS) @@ -471,7 +470,6 @@ if (GGML_HIPBLAS) add_compile_definitions(GGML_USE_HIPBLAS) add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) - add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) if (GGML_HIP_UMA) add_compile_definitions(GGML_HIP_UMA) diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu index 96a5adef5..a45a92224 100644 --- a/ggml/src/ggml-cuda/dmmv.cu +++ b/ggml/src/ggml-cuda/dmmv.cu @@ -2,16 +2,7 @@ #include "dequantize.cuh" #include "convert.cuh" -#ifndef K_QUANTS_PER_ITERATION -#define K_QUANTS_PER_ITERATION 2 -#else -static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); -#endif - static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row > nrows) return; @@ -22,15 +13,15 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, float tmp = 0; // partial sum for thread in warp - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + const int tid = threadIdx.x/2; // 0...15 + const int ix = threadIdx.x%2; // 0,1 - const int step = 16/K_QUANTS_PER_ITERATION; + const int step = 8; const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const int in = tid - step*im; // 0...15 or 0...7 - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 + const int l0 = 2*in; // 0...14 in steps of 2 const int q_offset = 32*im + l0; const int s_offset = 8*im; const int y_offset = 128*im + l0; @@ -39,7 +30,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const uint8_t * d = (const uint8_t *)aux; const uint8_t * m = (const uint8_t *)(aux + 2); - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + for (int i = ix; i < num_blocks_per_row; i += 2) { const float * y = yy + i * QK_K + y_offset; const uint8_t * q = x[i].qs + q_offset; @@ -54,7 +45,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, aux[3] = (a[1] >> 4) & 0x0f0f0f0f; float sum1 = 0, sum2 = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + for (int l = 0; l < 2; ++l) { sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) @@ -94,17 +85,17 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const uint16_t kmask1 = 0x0303; const uint16_t kmask2 = 0x0f0f; - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + const int tid = threadIdx.x/2; // 0...16 + const int ix = threadIdx.x%2; // 0,1 - const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop - const int step = 16/K_QUANTS_PER_ITERATION; - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0....15 or 0...7 + const int n = 2; // iterations in the inner loop + const int step = 8; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0....15 or 0...7 const uint8_t m = 1 << (4*im); - const int l0 = n*in; // 0...15 or 0...14 in steps of 2 + const int l0 = n*in; // 0...15 or 0...14 in steps of 2 const int q_offset = 32*im + l0; const int y_offset = 128*im + l0; @@ -113,7 +104,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const uint16_t s_shift = 4*im; - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + for (int i = ix; i < num_blocks_per_row; i += 2) { const float * y = yy + i * QK_K + y_offset; const uint8_t * q = x[i].qs + q_offset; @@ -163,14 +154,14 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 + const int tid = threadIdx.x/2; // 0...16 + const int ix = threadIdx.x%2; // 0,1 - const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + const int step = 4; - const int il = tid/step; // 0...3 - const int ir = tid - step*il; // 0...7 or 0...3 - const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + const int il = tid/step; // 0...3 + const int ir = tid - step*il; // 0...7 or 0...3 + const int n = 4; const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const int in = il%2; @@ -182,17 +173,12 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; -#if K_QUANTS_PER_ITERATION == 2 uint32_t q32[4]; const uint8_t * q4 = (const uint8_t *)q32; -#else - uint16_t q16[4]; - const uint8_t * q4 = (const uint8_t *)q16; -#endif float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + for (int i = ix; i < num_blocks_per_row; i += 2) { const float * y1 = yy + i*QK_K + y_offset; const float * y2 = y1 + 128; @@ -206,7 +192,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); -#if K_QUANTS_PER_ITERATION == 2 const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); const uint32_t * q2 = q1 + 16; @@ -223,25 +208,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; } tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; -#else - const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); - const uint16_t * q2 = q1 + 32; - - q16[0] = q1[0] & 0x0f0f; - q16[1] = q1[0] & 0xf0f0; - q16[2] = q2[0] & 0x0f0f; - q16[3] = q2[0] & 0xf0f0; - - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < 2; ++l) { - s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; - s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; - } - tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; -#endif - } // sum up partial sums and write back result @@ -341,9 +307,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, } static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - const int row = blockIdx.x*blockDim.y + threadIdx.y; if (row > nrows) return; @@ -352,21 +315,17 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const block_q6_K * x = (const block_q6_K *)vx + ib0; - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + const int tid = threadIdx.x/2; // 0...16 + const int ix = threadIdx.x%2; // 0, 1 - const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + const int step = 8; - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0...15 or 0...7 + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 -#if K_QUANTS_PER_ITERATION == 1 - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 - const int is = 0; -#else - const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int l0 = 4 * in; // 0, 4, 8, ..., 28 const int is = in / 4; -#endif + const int ql_offset = 64*im + l0; const int qh_offset = 32*im + l0; const int s_offset = 8*im + is; @@ -374,7 +333,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + for (int i = ix; i < num_blocks_per_row; i += 2) { const float * y = yy + i * QK_K + y_offset; const uint8_t * ql = x[i].ql + ql_offset; @@ -383,17 +342,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float d = x[i].d; -#if K_QUANTS_PER_ITERATION == 1 - float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) - + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) - + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) - + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) - + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) - + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) - + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) - +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); - tmp += sum; -#else float sum = 0; for (int l = 0; l < 4; ++l) { sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) @@ -402,8 +350,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); } tmp += sum; -#endif - } // sum up partial sums and write back result @@ -547,7 +493,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 + const int ny = 2; const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(32, ny, 1); @@ -556,7 +502,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; + const int ny = 1; const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(32, ny, 1); @@ -565,7 +511,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; + const int ny = 1; const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(32, ny, 1); @@ -580,7 +526,7 @@ static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, f static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; + const int ny = 1; const int block_num_y = (nrows + ny - 1) / ny; const dim3 block_nums(block_num_y, 1, 1); const dim3 block_dims(32, ny, 1); diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 5c343822f..7b1b317fa 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -124,9 +124,6 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx, float *__restrict__ dst, const int ncols, int nrows, const sycl::nd_item<3> &item_ct1) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); if (row > nrows) return; @@ -140,16 +137,16 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx, #if QK_K == 256 const int tid = - item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15 + item_ct1.get_local_id(2) / 2; // 0...15 const int ix = - item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + item_ct1.get_local_id(2) % 2; // 0,1 - const int step = 16/K_QUANTS_PER_ITERATION; + const int step = 8; - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0...15 or 0...7 + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 + const int l0 = 2*in; // 0...14 in steps of 2 const int q_offset = 32*im + l0; const int s_offset = 8*im; const int y_offset = 128*im + l0; @@ -158,7 +155,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx, const uint8_t * d = (const uint8_t *)aux; const uint8_t * m = (const uint8_t *)(aux + 2); - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + for (int i = ix; i < num_blocks_per_row; i += 2) { const float * y = yy + i * QK_K + y_offset; const uint8_t * q = x[i].qs + q_offset; @@ -173,7 +170,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx, aux[3] = (a[1] >> 4) & 0x0f0f0f0f; float sum1 = 0, sum2 = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + for (int l = 0; l < 2; ++l) { sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) @@ -190,18 +187,15 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx, } #else - const int tid = item_ct1.get_local_id(2) / - (2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7 - const int ix = item_ct1.get_local_id(2) % - (2 * K_QUANTS_PER_ITERATION); // 0....1 or 0...3 - const int offset = tid * K_QUANTS_PER_ITERATION; + const int tid = item_ct1.get_local_id(2) / 4; // 0...7 + const int ix = item_ct1.get_local_id(2) % 4; // 0...3 + const int offset = tid * 2; uint32_t uaux[2]; const uint8_t * d = (const uint8_t *)uaux; - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { - + for (int i = ix; i < num_blocks_per_row; i += 4) { const float * y = yy + i * QK_K + offset; const uint8_t * q = x[i].qs + offset; const uint32_t * s = (const uint32_t *)x[i].scales; @@ -213,7 +207,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx, x[i].dm.convert(); float sum1 = 0, sum2 = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + for (int l = 0; l < 2; ++l) { const uint8_t ql = q[l]; sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) + y[l+16] * d[1] * ((ql >> 2) & 3) @@ -268,14 +262,14 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx, const uint16_t kmask2 = 0x0f0f; const int tid = - item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + item_ct1.get_local_id(2) / 2; // 0...16 const int ix = - item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + item_ct1.get_local_id(2) % 2; // 0,1 - const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop - const int step = 16/K_QUANTS_PER_ITERATION; - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0....15 or 0...7 + const int n = 2; // iterations in the inner loop + const int step = 8; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0....15 or 0...7 const uint8_t m = 1 << (4*im); @@ -288,7 +282,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx, const uint16_t s_shift = 4*im; - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + for (int i = ix; i < num_blocks_per_row; i += 2) { const float * y = yy + i * QK_K + y_offset; const uint8_t * q = x[i].qs + q_offset; @@ -318,13 +312,13 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx, } #else - const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 - const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 - const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 - const int in = offset/8; // 0 or 1 - const int im = offset%8; // 0...7 + const int tid = item_ct1.get_local_id(2)/4; // 0...7 + const int ix = item_ct1.get_local_id(2)%4; // 0...3 + const int offset = tid * 2; // 0...14 + const int in = offset/8; // 0 or 1 + const int im = offset%8; // 0...7 - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + for (int i = ix; i < num_blocks_per_row; i += 4) { const float * y = yy + i * QK_K + offset; const uint8_t * q = x[i].qs + offset; @@ -333,7 +327,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx, const float dall = (float)x[i].d; float sum = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + for (int l = 0; l < 2; ++l) { const uint8_t hl = x[i].hmask[im+l] >> in; const uint8_t ql = q[l]; sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) @@ -384,15 +378,15 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, const uint16_t kmask3 = 0xc0c0; const int tid = - item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + item_ct1.get_local_id(2) / 2; // 0...16 const int ix = - item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + item_ct1.get_local_id(2) % 2; // 0,1 - const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + const int step = 4; const int il = tid/step; // 0...3 const int ir = tid - step*il; // 0...7 or 0...3 - const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + const int n = 4; const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const int in = il%2; @@ -404,17 +398,12 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, uint16_t aux[4]; const uint8_t * sc = (const uint8_t *)aux; -#if K_QUANTS_PER_ITERATION == 2 uint32_t q32[4]; const uint8_t * q4 = (const uint8_t *)q32; -#else - uint16_t q16[4]; - const uint8_t * q4 = (const uint8_t *)q16; -#endif float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + for (int i = ix; i < num_blocks_per_row; i += 2) { const float * y1 = yy + i*QK_K + y_offset; const float * y2 = y1 + 128; @@ -428,7 +417,6 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); -#if K_QUANTS_PER_ITERATION == 2 const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); const uint32_t * q2 = q1 + 16; @@ -447,38 +435,19 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f + s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) - dmin * smin; -#else - const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); - const uint16_t * q2 = q1 + 32; - - q16[0] = q1[0] & 0x0f0f; - q16[1] = q1[0] & 0xf0f0; - q16[2] = q2[0] & 0x0f0f; - q16[3] = q2[0] & 0xf0f0; - - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < 2; ++l) { - s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; - s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; - } - tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; -#endif - } #else - const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 - const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); + const int tid = item_ct1.get_local_id(2)/4; // 0...15 + const int ix = item_ct1.get_local_id(2)%4; - const int step = tid * K_QUANTS_PER_ITERATION; + const int step = tid * 2; uint16_t aux16[2]; const uint8_t * s = (const uint8_t *)aux16; float tmp = 0; - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + for (int i = ix; i < num_blocks_per_row; i += 4) { const uint8_t * q = x[i].qs + step; const float * y = yy + i*QK_K + step; const uint16_t * a = (const uint16_t *)x[i].scales; @@ -487,7 +456,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, const float d = (float)x[i].dm[0]; const float m = (float)x[i].dm[1]; float sum = 0.f; - for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + for (int j = 0; j < 2; ++j) { sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) @@ -609,19 +578,19 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx, } #else - const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 - const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); - const int step = tid * K_QUANTS_PER_ITERATION; + const int tid = item_ct1.get_local_id(2)/4; // 0...15 + const int ix = item_ct1.get_local_id(2)%4; + const int step = tid * 2; const int im = step/8; const int in = step%8; - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + for (int i = ix; i < num_blocks_per_row; i += 4) { const uint8_t * q = x[i].qs + step; const int8_t * s = x[i].scales; const float * y = yy + i*QK_K + step; const float d = x[i].d; float sum = 0.f; - for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + for (int j = 0; j < 2; ++j) { const uint8_t h = x[i].qh[in+j] >> im; sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16)) + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16)) @@ -646,9 +615,6 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx, static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows, const sycl::nd_item<3> &item_ct1) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); if (row > nrows) return; @@ -661,22 +627,18 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa #if QK_K == 256 const int tid = - item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + item_ct1.get_local_id(2) / 2; // 0...16 const int ix = - item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1 + item_ct1.get_local_id(2) % 2; // 0, 1 - const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + const int step = 8; const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const int in = tid - step*im; // 0...15 or 0...7 -#if K_QUANTS_PER_ITERATION == 1 - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 - const int is = 0; -#else const int l0 = 4 * in; // 0, 4, 8, ..., 28 const int is = in / 4; -#endif + const int ql_offset = 64*im + l0; const int qh_offset = 32*im + l0; const int s_offset = 8*im + is; @@ -684,7 +646,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + for (int i = ix; i < num_blocks_per_row; i += 2) { const float * y = yy + i * QK_K + y_offset; const uint8_t * ql = x[i].ql + ql_offset; @@ -693,17 +655,6 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa const float d = x[i].d; -#if K_QUANTS_PER_ITERATION == 1 - float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) - + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) - + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) - + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) - + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) - + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) - + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) - +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); - tmp += sum; -#else float sum = 0; for (int l = 0; l < 4; ++l) { sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) @@ -712,20 +663,18 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); } tmp += sum; -#endif - } #else - const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7 - const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3 + const int tid = item_ct1.get_local_id(2)/4; // 0...7 + const int ix = item_ct1.get_local_id(2)%4; // 0...3 - const int step = tid * K_QUANTS_PER_ITERATION; + const int step = tid * 2; float tmp = 0; // partial sum for thread in warp - for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + for (int i = ix; i < num_blocks_per_row; i += 4) { const float * y = yy + i * QK_K + step; const uint8_t * ql = x[i].ql + step; @@ -735,7 +684,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa const float d = x[i+0].d; float sum = 0; - for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + for (int j = 0; j < 2; ++j) { sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) @@ -871,7 +820,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 + const int ny = 2; const int block_num_y = (nrows + ny - 1) / ny; const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); @@ -887,7 +836,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; + const int ny = 1; const int block_num_y = (nrows + ny - 1) / ny; const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); @@ -903,7 +852,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; + const int ny = 1; const int block_num_y = (nrows + ny - 1) / ny; const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); @@ -932,7 +881,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; + const int ny = 1; const int block_num_y = (nrows + ny - 1) / ny; const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp index 340ab8e93..a3a584108 100644 --- a/ggml/src/ggml-sycl/presets.hpp +++ b/ggml/src/ggml-sycl/presets.hpp @@ -52,12 +52,6 @@ #define GGML_SYCL_MMV_Y 1 #endif -#ifndef K_QUANTS_PER_ITERATION -#define K_QUANTS_PER_ITERATION 2 -#else -static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); -#endif - #ifndef GGML_SYCL_PEER_MAX_BATCH_SIZE #define GGML_SYCL_PEER_MAX_BATCH_SIZE 128 #endif // GGML_SYCL_PEER_MAX_BATCH_SIZE diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index 32fda32a8..9814e82b0 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -41,12 +41,6 @@ #define MAX_VK_BUFFERS 256 -#ifndef K_QUANTS_PER_ITERATION -#define K_QUANTS_PER_ITERATION 1 -#else -static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); -#endif - #define VK_CHECK(err, msg) \ do { \ vk::Result err_ = (err); \ diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/vulkan-shaders/mul_mat_vec_base.comp index 5920bc936..cac4250fc 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec_base.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec_base.comp @@ -2,8 +2,6 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_8bit_storage : require -#define K_QUANTS_PER_ITERATION 2 - #ifdef MUL_MAT_ID #define EXPERT_COUNT 8 #endif diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp index ec8eadcd5..614f2bf8b 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -15,22 +15,22 @@ void main() { const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + const uint tid = gl_LocalInvocationID.x/2; // 0...16 + const uint ix = gl_LocalInvocationID.x%2; // 0, 1 - const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + const uint step = 8; const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const uint v_in = tid - step*v_im; // 0...15 or 0...7 - const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 + const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; const uint s_offset = 8*v_im; const uint y_offset = 128*v_im + l0; tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) { const uint y_idx = i * QUANT_K + y_offset; const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); @@ -38,7 +38,7 @@ void main() { FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + for (int l = 0; l < 2; ++l) { sum1 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3), fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3), fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3), diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp index 3ca4ad85a..5d37e24b7 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -15,17 +15,17 @@ void main() { const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + const uint tid = gl_LocalInvocationID.x/2; // 0...16 + const uint ix = gl_LocalInvocationID.x%2; // 0, 1 - const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + const uint step = 8; const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const uint v_in = tid - step*v_im; // 0...15 or 0...7 const uint8_t m = uint8_t(1 << (4 * v_im)); - const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 + const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; const uint y_offset = 128*v_im + l0; @@ -33,13 +33,13 @@ void main() { const uint s_shift = 4 * v_im; - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) { const uint y_idx = i * QUANT_K + y_offset; const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); FLOAT_TYPE sum = FLOAT_TYPE(0.0); - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + for (int l = 0; l < 2; ++l) { sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp index d91e00e10..36b6e9460 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -15,14 +15,14 @@ void main() { const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + const uint tid = gl_LocalInvocationID.x/2; // 0...16 + const uint ix = gl_LocalInvocationID.x%2; // 0, 1 - const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + const uint step = 4; const uint il = tid/step; // 0...3 const uint ir = tid - step*il; // 0...7 or 0...3 - const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + const uint n = 4; const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_in = il % 2; @@ -33,7 +33,7 @@ void main() { tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) { const uint y1_idx = i * QUANT_K + y_offset; const uint y2_idx = y1_idx + 128; @@ -49,7 +49,6 @@ void main() { const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2)); const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2)); -#if K_QUANTS_PER_ITERATION == 2 const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf); @@ -78,30 +77,6 @@ void main() { fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7))))))))))))))); const uint tmp_idx = 16 * ix + tid; tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx])); -#else - const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); - const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); - const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4); - const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4); - const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf); - const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf); - const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4); - const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4); - - const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1); - const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3); - const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5); - const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7); - const FLOAT_TYPE smin = - fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7, - + fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7))))))); - - tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + - sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin); - const uint tmp_idx = 16 * ix + tid; - tmp[tmp_idx] = fma(dall, (fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f), - fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), fma(sw, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))))), fma(-dmin, smin, tmp[tmp_idx])); -#endif } // sum up partial sums and write back result diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp index 95c286eeb..122a82f1f 100644 --- a/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -15,21 +15,16 @@ void main() { const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + const uint tid = gl_LocalInvocationID.x/2; // 0...16 + const uint ix = gl_LocalInvocationID.x%2; // 0, 1 - const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + const uint step = 8; const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const uint v_in = tid - step*v_im; // 0...15 or 0...7 -#if K_QUANTS_PER_ITERATION == 1 - const uint l0 = v_in; // 0...15 - const uint is = 0; -#else const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 const uint is = v_in / 4; -#endif const uint ql_offset = 64*v_im + l0; const uint qh_offset = 32*v_im + l0; @@ -38,22 +33,11 @@ void main() { tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) { const uint y_idx = i * QUANT_K + y_offset; const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); -#if K_QUANTS_PER_ITERATION == 1 - const uint tmp_idx = 16 * ix + tid; - tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32), - fma(FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx])))))))); -#else FLOAT_TYPE sum = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 4; ++l) { sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32), @@ -62,7 +46,6 @@ void main() { fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum)))); } tmp[16 * ix + tid] += sum; -#endif } // sum up partial sums and write back result From ccb45186d092118c8cd4d01e7dff0cd2bcb45cce Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 26 Aug 2024 09:52:02 +0300 Subject: [PATCH 2/2] docs : remove references --- docs/build.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/build.md b/docs/build.md index 152d46d6f..588568d16 100644 --- a/docs/build.md +++ b/docs/build.md @@ -192,7 +192,6 @@ The following compilation options are also available to tweak performance: | GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | | GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models | | GGML_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | -| GGML_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | | GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. | @@ -266,7 +265,6 @@ The following compilation options are also available to tweak performance (yes, |------------------------|------------------------|---------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | GGML_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the HIP dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | | GGML_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the HIP mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. | -| GGML_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per HIP thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | ### Vulkan