ggml : SOTA 2-bit quants (add IQ2_XS) (#4856)
* iq2_xs: basics * iq2_xs: this should have been in the basics * iq2_xs: CUDA and scalar CPU works * iq2_xs: WIP Metal * iq2_xs: Metal now works * iq2_xs: working, but dog slow, ARM_NEON dot product * iq2_xs: better ARM_NEON dot product We are now at 19.5 t/s for TG-128 and 61 t/s for PP-512 when running on the CPU. * iq2_xs: AVX2 dot product - 19.5 t/s * iq2_xs: faster AVX2 dit product 21.4 t/s for TG-128, 59.2 t/s for PP-512. The latter is 2x compared to the previous version. * iq2_xs: had forgotten to delete iq2-data.h * Add llama enum for IQ2_XS --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
3ba5b8ca8e
commit
49662cbed3
10 changed files with 1038 additions and 28 deletions
42
ggml-metal.m
42
ggml-metal.m
|
@ -89,6 +89,7 @@ struct ggml_metal_context {
|
|||
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_i32);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
|
||||
GGML_METAL_DECL_KERNEL(get_rows_iq2_xs);
|
||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||
GGML_METAL_DECL_KERNEL(group_norm);
|
||||
GGML_METAL_DECL_KERNEL(norm);
|
||||
|
@ -108,6 +109,7 @@ struct ggml_metal_context {
|
|||
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
|
||||
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
|
||||
|
@ -124,6 +126,7 @@ struct ggml_metal_context {
|
|||
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
|
||||
|
@ -137,6 +140,7 @@ struct ggml_metal_context {
|
|||
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
|
||||
|
@ -150,6 +154,7 @@ struct ggml_metal_context {
|
|||
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
|
||||
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32);
|
||||
GGML_METAL_DECL_KERNEL(rope_f32);
|
||||
GGML_METAL_DECL_KERNEL(rope_f16);
|
||||
GGML_METAL_DECL_KERNEL(alibi_f32);
|
||||
|
@ -385,6 +390,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_i32);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
|
||||
GGML_METAL_ADD_KERNEL(get_rows_iq2_xs);
|
||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||
GGML_METAL_ADD_KERNEL(group_norm);
|
||||
GGML_METAL_ADD_KERNEL(norm);
|
||||
|
@ -404,6 +410,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_iq2_xs_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
|
||||
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
|
||||
|
@ -420,6 +427,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xs_f32);
|
||||
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
|
||||
|
@ -434,6 +442,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_iq2_xs_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
|
||||
|
@ -447,6 +456,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
|
||||
GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xs_f32);
|
||||
}
|
||||
GGML_METAL_ADD_KERNEL(rope_f32);
|
||||
GGML_METAL_ADD_KERNEL(rope_f16);
|
||||
|
@ -513,6 +523,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_i32);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
|
||||
GGML_METAL_DEL_KERNEL(get_rows_iq2_xs);
|
||||
GGML_METAL_DEL_KERNEL(rms_norm);
|
||||
GGML_METAL_DEL_KERNEL(group_norm);
|
||||
GGML_METAL_DEL_KERNEL(norm);
|
||||
|
@ -532,6 +543,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
|
||||
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
|
||||
|
@ -548,6 +560,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32);
|
||||
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
|
||||
|
@ -562,6 +575,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
|
||||
|
@ -575,6 +589,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
|
||||
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32);
|
||||
}
|
||||
GGML_METAL_DEL_KERNEL(rope_f32);
|
||||
GGML_METAL_DEL_KERNEL(rope_f16);
|
||||
|
@ -1561,6 +1576,7 @@ bool ggml_metal_graph_compute(
|
|||
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
|
||||
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
|
||||
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
|
||||
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xs_f32]; break;
|
||||
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
||||
}
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
|
@ -1679,6 +1695,12 @@ bool ggml_metal_graph_compute(
|
|||
nth1 = 16;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
||||
|
@ -1712,12 +1734,12 @@ bool ggml_metal_graph_compute(
|
|||
|
||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
||||
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
||||
//src0t == GGML_TYPE_IQ2_XXS ||
|
||||
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_IQ2_XXS) {
|
||||
[encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
|
||||
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q4_K) {
|
||||
|
@ -1810,6 +1832,7 @@ bool ggml_metal_graph_compute(
|
|||
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
|
||||
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
|
||||
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
|
||||
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break;
|
||||
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
||||
}
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
|
@ -1931,6 +1954,12 @@ bool ggml_metal_graph_compute(
|
|||
nth1 = 16;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
|
||||
} break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
{
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32];
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
||||
|
@ -1980,12 +2009,12 @@ bool ggml_metal_graph_compute(
|
|||
|
||||
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
||||
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
||||
//src2t == GGML_TYPE_IQ2_XXS ||
|
||||
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_IQ2_XXS) {
|
||||
[encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
|
||||
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
||||
const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q4_K) {
|
||||
|
@ -2026,6 +2055,7 @@ bool ggml_metal_graph_compute(
|
|||
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
|
||||
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
|
||||
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
|
||||
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue