IQ4_NL: 4-bit non-linear quants with blocks of 32 (#5590)
* iq4_nl: squash commits for easier rebase * Basics (quantize, dequantize) * CUDA dequantize and dot product * Slightly faster CUDA dot product (120 t/s) * Switch to 6-bit scales * Scalar dot product * AVX2 dot product * ARM_NEON dot product * Works on metal, but still slow * Slightly better Metal dot product * Another small Metal improvement * Metal dot product is getting there * Faster CUDA dot product * Add 1/8 ffn_down layers as Q5_K when no imatrix has been provided * Report the actual bpw * Add _xs mix that is 4.05 bpw for non-MoE models * Remove IQ4_XS for now, slightly adjust kvalues_iq4nl * AVX2 dot product uses Q8_0 instead of Q8_K * Add to test-backend-ops * Minor fix * Also use use Q5_K for attn_output in MoE models * Fixes after merging latest master * Switching to blocks of 32 * AVX2 for blocks of 32 * Scaler dot product for blocks of 32 * ARM_NEON dot product for blocks of 32 * Metal kernels for blocks of 32 * Slightly faster Metal kernels * iq4_nl: Fix after merging with master * iq4_nl: another fix after merging with master * Use IQ4_NL instead of Q4_K when using k-quants is not possible * Fix typo that makes several tests fail * It was the ggml_vdotq thing missed inside the brackets --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
6560bed3f0
commit
a14679cc30
11 changed files with 640 additions and 7 deletions
98
ggml-cuda.cu
98
ggml-cuda.cu
|
@ -528,6 +528,15 @@ typedef struct {
|
|||
} block_iq1_s;
|
||||
static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
|
||||
|
||||
#define QK4_NL 32
|
||||
#define QR4_NL 2
|
||||
#define QI4_NL (QK4_NL / (4*QR4_NL))
|
||||
typedef struct {
|
||||
half d;
|
||||
uint8_t qs[QK4_NL/2];
|
||||
} block_iq4_nl;
|
||||
static_assert(sizeof(block_iq4_nl) == sizeof(ggml_fp16_t) + QK4_NL/2, "wrong iq4_nl block size/padding");
|
||||
|
||||
#define WARP_SIZE 32
|
||||
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
|
||||
|
||||
|
@ -1987,6 +1996,26 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
|
|||
|
||||
}
|
||||
|
||||
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
||||
const uint8_t * q4 = x[ib].qs + 4*il;
|
||||
const float d = (float)x[ib].d;
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
|
||||
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
|
||||
|
||||
|
@ -4732,6 +4761,56 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
|
|||
#endif
|
||||
}
|
||||
|
||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4, const uint8_t * values,
|
||||
int & val1, int & val2) {
|
||||
|
||||
uint32_t aux32; const uint8_t * q8 = (const uint8_t *)&aux32;
|
||||
aux32 = q4 & 0x0f0f0f0f;
|
||||
uint16_t v1 = values[q8[0]] | (values[q8[1]] << 8);
|
||||
uint16_t v2 = values[q8[2]] | (values[q8[3]] << 8);
|
||||
val1 = v1 | (v2 << 16);
|
||||
aux32 = (q4 >> 4) & 0x0f0f0f0f;
|
||||
v1 = values[q8[0]] | (values[q8[1]] << 8);
|
||||
v2 = values[q8[2]] | (values[q8[3]] << 8);
|
||||
val2 = v1 | (v2 << 16);
|
||||
}
|
||||
#endif
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
|
||||
|
||||
const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
|
||||
|
||||
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
|
||||
const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
|
||||
const int32_t * q8 = (const int32_t *)bq8_1->qs + iqs;
|
||||
|
||||
const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
|
||||
|
||||
int v1, v2;
|
||||
int sumi1 = 0, sumi2 = 0;
|
||||
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
|
||||
const uint32_t aux = q4[2*l] | (q4[2*l+1] << 16);
|
||||
get_int_from_table_16(aux, values, v1, v2);
|
||||
sumi1 = __dp4a(v1, q8[l+0], sumi1);
|
||||
sumi2 = __dp4a(v2, q8[l+4], sumi2);
|
||||
}
|
||||
|
||||
#else
|
||||
const uint8_t * q4 = bq->qs + 4*iqs;
|
||||
const int8_t * q8 = bq8_1->qs + 4*iqs;
|
||||
|
||||
int sumi1 = 0, sumi2 = 0;
|
||||
for (int l = 0; l < 4*VDR_Q4_0_Q8_1_MMVQ; ++l) {
|
||||
sumi1 += q8[l+ 0] * kvalues_iq4nl[q4[l] & 0xf];
|
||||
sumi2 += q8[l+16] * kvalues_iq4nl[q4[l] >> 4];
|
||||
}
|
||||
#endif
|
||||
const float d = (float)bq->d * __low2float(bq8_1->ds);
|
||||
return d * (sumi1 + sumi2);
|
||||
}
|
||||
|
||||
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
|
||||
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
|
||||
static __device__ __forceinline__ void mul_mat_q(
|
||||
|
@ -6777,6 +6856,12 @@ static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, c
|
|||
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int nb = (k + QK_K - 1) / QK_K;
|
||||
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
||||
|
@ -6818,6 +6903,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|||
return dequantize_row_iq3_xxs_cuda;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
return dequantize_row_iq1_s_cuda;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return dequantize_row_iq4_nl_cuda;
|
||||
case GGML_TYPE_F32:
|
||||
return convert_unary_cuda<float>;
|
||||
default:
|
||||
|
@ -6855,6 +6942,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
|||
return dequantize_row_iq3_xxs_cuda;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
return dequantize_row_iq1_s_cuda;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return dequantize_row_iq4_nl_cuda;
|
||||
case GGML_TYPE_F16:
|
||||
return convert_unary_cuda<half>;
|
||||
default:
|
||||
|
@ -8599,6 +8688,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
|
|||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
|
@ -8623,6 +8713,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
|
|||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
return max_compute_capability >= CC_VOLTA ? 128 : 64;
|
||||
case GGML_TYPE_Q6_K:
|
||||
return 64;
|
||||
|
@ -8724,6 +8815,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
|
|||
mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
|
||||
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
|
||||
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
|
@ -11446,7 +11541,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
|||
return false;
|
||||
}
|
||||
ggml_type a_type = a->type;
|
||||
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ1_S) {
|
||||
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS ||
|
||||
a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL) {
|
||||
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
|
||||
return false;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue