metal: template for mat-vec multiplication kernels

This commit is contained in:
lshzh-ww 2023-08-29 23:24:19 -04:00
parent 06abf8eeba
commit bca5d0c2de
2 changed files with 674 additions and 1321 deletions

View file

@ -76,14 +76,15 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32); GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
@ -205,14 +206,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32); GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
@ -270,14 +272,15 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(rms_norm); GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm); GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32); GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32); GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32); GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
@ -832,97 +835,42 @@ void ggml_metal_graph_compute(
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
[encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else { } else if (ggml_is_contiguous(src0) &&
int nth0 = 32; ggml_is_contiguous(src1) &&
int nth1 = 1; src1t == GGML_TYPE_F32 &&
ne00%32 == 0) {
// use custom matrix x vector kernel // use custom matrix x vector kernel
switch (src0t) { switch (src0->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; break;
{ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32]; break;
nth0 = 64; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; break;
nth1 = 1; case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32]; break;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32]; break;
} break; case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32]; break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32]; break;
{ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32]; break;
GGML_ASSERT(ne02 == 1); case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32]; break;
GGML_ASSERT(ne12 == 1); default: GGML_ASSERT(false && "MUL MAT-VEC not implemented");
}
nth0 = 8; int buffer_size_aligned = (512 / ggml_blck_size(src0t) * ggml_element_size(src0) + 31) / 32 * 32;
nth1 = 8; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
} break; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
case GGML_TYPE_Q4_1: [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
{ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
GGML_ASSERT(ne02 == 1); [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
GGML_ASSERT(ne12 == 1); [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
nth0 = 8; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
nth1 = 8; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
} break; [encoder setThreadgroupMemoryLength:8 * buffer_size_aligned atIndex:0];
case GGML_TYPE_Q8_0: [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)];
{ } else {
GGML_ASSERT(ne02 == 1); switch (src0->type) {
GGML_ASSERT(ne12 == 1); case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; break;
default: GGML_ASSERT(false && " not implemented");
nth0 = 8; }
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
} break;
case GGML_TYPE_Q2_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
} break;
case GGML_TYPE_Q3_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
} break;
case GGML_TYPE_Q4_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
} break;
case GGML_TYPE_Q5_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
} break;
case GGML_TYPE_Q6_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
} break;
default:
{
metal_printf("Asserting on type %d\n",(int)src0t);
GGML_ASSERT(false && "not implemented");
}
};
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@ -941,27 +889,8 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
[encoder setThreadgroupMemoryLength:64*sizeof(float) atIndex:0];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)];
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#else
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#endif
}
else if (src0t == GGML_TYPE_Q5_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
} }
} break; } break;
case GGML_OP_GET_ROWS: case GGML_OP_GET_ROWS:

File diff suppressed because it is too large Load diff