diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 050bdd1de..51c44d857 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -15,13 +15,63 @@ typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, cons typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00); typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max); +enum mmq_q8_1_ds_layout { + MMQ_Q8_1_DS_LAYOUT_D4, + MMQ_Q8_1_DS_LAYOUT_DS4, + MMQ_Q8_1_DS_LAYOUT_D2S6, +}; + struct block_q8_1_mmq { - half2 ds[4]; - int8_t qs[4*QK8_1]; + // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block. + // The y float data is first grouped as blocks of 128 values. + // These blocks are then treated as individual data values and transposed. + // + // To avoid shared memory bank conflicts each block is padded with 16 bytes. + // This padding is also used to store block scales/partial sums. + // The scales multiplied with the quantized data are equal to the unquantized values. + // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization) + // and are only needed for performance reasons. + // + // The exact data stored depends on the x data type. 
+ union { + float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3 + half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3 + half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values, + // stored as d0,d1,s0,s1,s2,s3,s4,s5 + }; + int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each }; static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); +static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { + switch (type_x) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q5_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q5_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q8_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q2_K: + return MMQ_Q8_1_DS_LAYOUT_D2S6; + case GGML_TYPE_Q3_K: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + return MMQ_Q8_1_DS_LAYOUT_D4; + default: + GGML_ASSERT(false); + break; + } +} + struct tile_x_sizes { int qs; int dm; @@ -2362,35 +2412,6 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -static int mmq_need_sum(const ggml_type type_x) { - switch (type_x) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - return 1; - case GGML_TYPE_Q5_0: - return 0; - case GGML_TYPE_Q5_1: - return 1; - case GGML_TYPE_Q8_0: - return 0; - case GGML_TYPE_Q2_K: - return 2; - case GGML_TYPE_Q3_K: - return 0; - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - return 1; - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ4_NL: - return 0; - default: - GGML_ASSERT(false); - break; - } - return -1; -} - template static __device__ void 
mul_mat_q_process_tile( const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 1bc499b41..aa7f1eff0 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -37,10 +37,13 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest reinterpret_cast(y[ib].ds.y) = sum; } -template +template static __global__ void quantize_mmq_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { + constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; + constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; + const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; if (ix0 >= kx0_padded) { @@ -57,22 +60,26 @@ static __global__ void quantize_mmq_q8_1( const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel const int64_t iqs = ix0 % (4*QK8_1); // quant index in block + // Load 4 floats per thread and calculate max. abs. value between them: const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); float amax = fabsf(xi.x); amax = fmaxf(amax, fabsf(xi.y)); amax = fmaxf(amax, fabsf(xi.z)); amax = fmaxf(amax, fabsf(xi.w)); + // Exchange max. abs. value between vals_per_scale/4 threads. #pragma unroll - for (int mask = need_sum == 2 ? 8 : 4; mask > 0; mask >>= 1) { + for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) { amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE)); } float sum; - if (need_sum > 0) { + if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) { sum = xi.x + xi.y + xi.z + xi.w; + + // Exchange calculated sum across vals_per_sum/4 threads. #pragma unroll - for (int mask = need_sum == 2 ? 
2 : 4; mask > 0; mask >>= 1) { + for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) { sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE); } } @@ -84,36 +91,38 @@ static __global__ void quantize_mmq_q8_1( q.z = roundf(xi.z*d_inv); q.w = roundf(xi.w*d_inv); + // Write back 4 int8 values as a single 32 bit value for better memory bandwidth: char4 * yqs4 = (char4 *) y[ib].qs; yqs4[iqs/4] = q; - if (need_sum < 2) { - if (iqs % QK8_1 != 0) { + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) { + if (iqs % 16 != 0 || iqs >= 96) { + return; + } + + y[ib].d2s6[2 + iqs/16] = sum; + + if (iqs % 64 != 0) { return; } const float d = 1.0f / d_inv; - if (need_sum > 0) { - y[ib].ds[iqs/QK8_1] = make_half2(d, sum); - } else { - ((float *) y[ib].ds)[iqs/QK8_1] = d; - } + y[ib].d2s6[iqs/64] = d; + + return; + } + + if (iqs % 32 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) { + y[ib].ds4[iqs/32] = make_half2(d, sum); } else { - if (iqs % (QK8_1/2) != 0 || iqs >= (3*QK8_1)) { - return; - } - - half * ydsh = (half *) y[ib].ds; - ydsh[2 + iqs/(QK8_1/2)] = sum; - - if (iqs % (QK8_1*2) != 0) { - return; - } - - const float d = 1.0f / d_inv; - - ydsh[iqs/(QK8_1*2)] = d; + y[ib].d4[iqs/32] = d; } } @@ -140,15 +149,18 @@ void quantize_mmq_q8_1_cuda( const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); const dim3 num_blocks(block_num_x, kx1, channels); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); - switch (mmq_need_sum(type_x)) { - case 0: - quantize_mmq_q8_1<0><<>>(x, vy, kx0, kx1, kx0_padded); + switch (mmq_get_q8_1_ds_layout(type_x)) { + case MMQ_Q8_1_DS_LAYOUT_D4: + quantize_mmq_q8_1 + <<>>(x, vy, kx0, kx1, kx0_padded); break; - case 1: - quantize_mmq_q8_1<1><<>>(x, vy, kx0, kx1, kx0_padded); + case MMQ_Q8_1_DS_LAYOUT_DS4: + quantize_mmq_q8_1 + <<>>(x, vy, kx0, kx1, kx0_padded); break; - case 2: - quantize_mmq_q8_1<2><<>>(x, vy, kx0, kx1, kx0_padded); + case 
MMQ_Q8_1_DS_LAYOUT_D2S6: + quantize_mmq_q8_1 + <<>>(x, vy, kx0, kx1, kx0_padded); break; default: GGML_ASSERT(false);