iq2_xs: WIP Metal

This commit is contained in:
Iwan Kawrakow 2024-01-09 17:46:27 +01:00
parent 9b6e38d8c0
commit 0aacd55159
2 changed files with 275 additions and 6 deletions

View file

@ -89,6 +89,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(get_rows_i32);
GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
GGML_METAL_DECL_KERNEL(get_rows_iq2_xs);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(group_norm);
GGML_METAL_DECL_KERNEL(norm);
@ -108,6 +109,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
@ -124,6 +126,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@ -137,6 +140,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
@ -150,6 +154,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32);
GGML_METAL_DECL_KERNEL(rope_f32);
GGML_METAL_DECL_KERNEL(rope_f16);
GGML_METAL_DECL_KERNEL(alibi_f32);
@ -385,6 +390,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(get_rows_i32);
GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
GGML_METAL_ADD_KERNEL(get_rows_iq2_xs);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(group_norm);
GGML_METAL_ADD_KERNEL(norm);
@ -404,6 +410,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
GGML_METAL_ADD_KERNEL(mul_mv_iq2_xs_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
@ -420,6 +427,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xs_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
@ -434,6 +442,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
GGML_METAL_ADD_KERNEL(mul_mm_iq2_xs_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
@ -447,6 +456,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xs_f32);
}
GGML_METAL_ADD_KERNEL(rope_f32);
GGML_METAL_ADD_KERNEL(rope_f16);
@ -513,6 +523,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
GGML_METAL_DEL_KERNEL(get_rows_i32);
GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
GGML_METAL_DEL_KERNEL(get_rows_iq2_xs);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(group_norm);
GGML_METAL_DEL_KERNEL(norm);
@ -532,6 +543,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
//GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
@ -548,6 +560,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
@ -562,6 +575,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
@ -575,6 +589,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32);
}
GGML_METAL_DEL_KERNEL(rope_f32);
GGML_METAL_DEL_KERNEL(rope_f16);
@ -1557,6 +1572,7 @@ bool ggml_metal_graph_compute(
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xs_f32]; break;
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1675,6 +1691,12 @@ bool ggml_metal_graph_compute(
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
} break;
case GGML_TYPE_IQ2_XS:
{
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32];
} break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
@ -1708,12 +1730,12 @@ bool ggml_metal_graph_compute(
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
//src0t == GGML_TYPE_IQ2_XXS ||
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS) {
[encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q4_K) {
@ -1806,6 +1828,7 @@ bool ggml_metal_graph_compute(
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break;
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -1927,6 +1950,12 @@ bool ggml_metal_graph_compute(
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
} break;
case GGML_TYPE_IQ2_XS:
{
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32];
} break;
default:
{
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
@ -1976,12 +2005,12 @@ bool ggml_metal_graph_compute(
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
//src2t == GGML_TYPE_IQ2_XXS ||
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_IQ2_XXS) {
[encoder setThreadgroupMemoryLength:(256*8+128) atIndex:0];
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src2t == GGML_TYPE_Q4_K) {
@ -2022,6 +2051,7 @@ bool ggml_metal_graph_compute(
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
case GGML_TYPE_I32: [encoder setComputePipelineState:ctx->pipeline_get_rows_i32]; break;
case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break;
default: GGML_ASSERT(false && "not implemented");
}