SOTA 2-bit quants (#4773)

* iq2_xxs: basics

* iq2_xxs: scalar and AVX2 dot products

Needed to change Q8_K to have quants in the -127...127 range,
else the IQ2_XXS AVX implementation becomes very awkward.
The alternative would have been to use Q8_0 instead. Perhaps
I'll change later, for now this is what we have.

* iq2_xxs: ARM_NEON dot product

Somehow strangely slow (112 ms/token).

* iq2_xxs: WIP Metal

Dequantize works, something is still wrong with the
dot product.

* iq2_xxs: Metal dot product now works

We have
PP-512 = 475 t/s
TG-128 = 47.3 t/s

Not the greatest performance, but not complete garbage either.

* iq2_xxs: slighty faster dot product

TG-128 is now 48.4 t/s

* iq2_xxs: slighty faster dot product

TG-128 is now 50.9 t/s

* iq2_xxs: even faster Metal dot product

TG-128 is now 54.1 t/s.

Strangely enough, putting the signs lookup table
into shared memory has a bigger impact than the
grid values being in shared memory.

* iq2_xxs: dequantize CUDA kernel - fix conflict with master

* iq2_xxs: quantized CUDA dot product (MMVQ)

We get TG-128 = 153.1 t/s

* iq2_xxs: slightly faster CUDA dot product

TG-128 is now at 155.1 t/s.

* iq2_xxs: add to llama ftype enum

* iq2_xxs: fix MoE on Metal

* Fix missing MMQ ops when on hipBLAS

I had put the ggml_supports_mmq call at the wrong place.

* Fix bug in qequantize_row_iq2_xxs

The 0.25f factor was missing.
Great detective work by @ggerganov!

* Fixing tests

* PR suggestion

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow 2024-01-08 16:02:32 +01:00 committed by GitHub
parent 668b31fc7d
commit dd5ae06405
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 902 additions and 1 deletions

26
ggml.c
View file

@ -573,6 +573,17 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot = ggml_vec_dot_q6_K_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
[GGML_TYPE_IQ2_XXS] = {
.type_name = "iq2_xxs",
.blck_size = QK_K,
.type_size = sizeof(block_iq2_xxs),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_iq2_xxs,
.from_float = quantize_row_iq2_xxs,
.from_float_reference = (ggml_from_float_t) quantize_row_iq2_xxs_reference,
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
.vec_dot_type = GGML_TYPE_Q8_K,
},
[GGML_TYPE_Q8_K] = {
.type_name = "q8_K",
.blck_size = QK_K,
@ -2111,6 +2122,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break;
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
}
@ -7436,6 +7448,7 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
{
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
} break;
@ -7700,6 +7713,7 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
{
ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
} break;
@ -7814,6 +7828,7 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
default:
{
GGML_ASSERT(false);
@ -10455,6 +10470,7 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
{
ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst);
} break;
@ -10629,6 +10645,7 @@ static void ggml_compute_forward_set(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
default:
{
GGML_ASSERT(false);
@ -10823,6 +10840,7 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
{
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
} break;
@ -11459,6 +11477,7 @@ static void ggml_compute_forward_alibi(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_Q8_K:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
@ -11533,6 +11552,7 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_Q8_K:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
@ -18648,6 +18668,12 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
block_q6_K * block = (block_q6_K*)dst + start / QK_K;
result = ggml_quantize_q6_K(src + start, block, n, n, hist);
} break;
case GGML_TYPE_IQ2_XXS:
{
GGML_ASSERT(start % QK_K == 0);
block_iq2_xxs * block = (block_iq2_xxs*)dst + start / QK_K;
result = ggml_quantize_iq2_xxs(src + start, block, n, n, hist);
} break;
case GGML_TYPE_F16:
{
int elemsize = sizeof(ggml_fp16_t);