metal : add Q2_K implementation (#1762)

* metal : add Q2_K implementation

  27.1 ms/token on an M2 Max 30-core GPU, so about the same speed as Q4_0. Memory throughput is ~156 GB/s. The access pattern used in the Q2_K CUDA implementation resulted in significantly lower performance here (~31 ms/token).

* Fix merge conflicts

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
parent 0bf7cf1b29
commit 72ff5282bf
2 changed files with 200 additions and 18 deletions
ggml-metal.m (17 changes)
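For context (not part of this diff): the Q2_K super-block that these kernels read comes from ggml's k-quants. The layout below is a reference sketch; QK_K, the field names, and the ggml_fp16_t typedef are taken from the ggml sources rather than from this commit, so treat them as assumptions.

// Reference sketch of the Q2_K super-block layout (assumed from ggml's
// k-quant definitions; not part of this diff).
#include <stdint.h>

#define QK_K 256                  // weights per super-block

typedef uint16_t ggml_fp16_t;     // fp16 stored in a 16-bit container

typedef struct {
    uint8_t scales[QK_K/16];      // 16 x (4-bit scale, 4-bit min), one pair per 16 weights
    uint8_t qs[QK_K/4];           // 256 x 2-bit quants, packed 4 per byte
    ggml_fp16_t d;                // super-block scale for the quantized scales
    ggml_fp16_t dmin;             // super-block scale for the quantized mins
} block_q2_K;                     // 16 + 64 + 2 + 2 = 84 bytes -> 84*8/256 = 2.625 bits per weight

At roughly 84 bytes per 256 weights the multiplication is memory bound, so the ~156 GB/s figure in the commit message is essentially how fast these blocks can be streamed from memory, which is why the access pattern matters so much for this kernel.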
@@ -49,11 +49,13 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DECL_KERNEL(get_rows_q2_k);
     GGML_METAL_DECL_KERNEL(get_rows_q4_k);
     GGML_METAL_DECL_KERNEL(get_rows_q6_k);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
     GGML_METAL_DECL_KERNEL(rope);
@@ -137,11 +139,13 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
+        GGML_METAL_ADD_KERNEL(get_rows_q2_k);
         GGML_METAL_ADD_KERNEL(get_rows_q4_k);
         GGML_METAL_ADD_KERNEL(get_rows_q6_k);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
         GGML_METAL_ADD_KERNEL(rope);
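The two hunks above only declare and register the new Q2_K kernels. A minimal sketch of what that registration amounts to, assuming the usual Metal pattern of looking a function up by name in the compiled library and building a compute pipeline from it. The macro bodies and the exact "kernel_..." names live in ggml-metal.m / ggml-metal.metal and are not shown in this diff; make_q2_k_pipeline is a hypothetical helper, not code from the commit.

// Hedged sketch: build a compute pipeline for one of the new Q2_K kernels.
// The kernel name string is an assumption based on the naming used elsewhere.
#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

static id<MTLComputePipelineState> make_q2_k_pipeline(id<MTLDevice>  device,
                                                      id<MTLLibrary> library,
                                                      NSString     * name) {
    NSError * error = nil;

    // look up the kernel function compiled into the Metal library
    id<MTLFunction> fn = [library newFunctionWithName:name];
    if (fn == nil) {
        NSLog(@"kernel %@ not found", name);
        return nil;
    }

    // compile it into a pipeline state object that can later be set on a compute encoder
    id<MTLComputePipelineState> pso =
        [device newComputePipelineStateWithFunction:fn error:&error];
    if (pso == nil) {
        NSLog(@"failed to create pipeline for %@: %@", name, error);
    }
    return pso;
}

// usage sketch: make_q2_k_pipeline(device, library, @"kernel_mul_mat_q2_k_f32")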
@@ -525,6 +529,15 @@ void ggml_metal_graph_compute(
                                 nth1 = 4;
                                 [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                             } break;
+                        case GGML_TYPE_Q2_K:
+                            {
+                                GGML_ASSERT(ne02 == 1);
+                                GGML_ASSERT(ne12 == 1);
+
+                                nth0 = 4;
+                                nth1 = 16;
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+                            } break;
                         case GGML_TYPE_Q4_K:
                             {
                                 GGML_ASSERT(ne02 == 1);
@@ -570,6 +583,9 @@ void ggml_metal_graph_compute(
                     if (src0t == GGML_TYPE_Q4_0) {
                         [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    } else if (src0t == GGML_TYPE_Q2_K) {
+                        [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     } else if (src0t == GGML_TYPE_Q4_K) {
                         [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
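A quick sanity check of the Q2_K dispatch in the hunk above, using nth0 = 4 and nth1 = 16 from the previous hunk: each threadgroup gets 64 threads and 64 floats (256 bytes) of threadgroup memory for partial sums, and the grid is one threadgroup per src0 row (ne01), with no spread over ne11 as in the Q4_0/Q4_K paths. The ne01 value below is a made-up example, not taken from the commit.

// Hedged back-of-the-envelope for the Q2_K dispatch parameters above.
#include <stdio.h>

int main(void) {
    const int nth0 = 4;     // threadsPerThreadgroup.x, from the Q2_K case
    const int nth1 = 16;    // threadsPerThreadgroup.y, from the Q2_K case
    const int ne01 = 4096;  // example number of src0 rows (hypothetical)

    const int threads_per_group = nth0 * nth1;                            // 64 threads
    const int tg_mem_bytes      = threads_per_group * (int)sizeof(float); // 256 bytes

    printf("threadgroups = %d, threads/group = %d, threadgroup memory = %d B\n",
           ne01, threads_per_group, tg_mem_bytes);
    return 0;
}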
@@ -591,6 +607,7 @@ void ggml_metal_graph_compute(
                 switch (src0->type) {
                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                     case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
                     case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
                     case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
                     default: GGML_ASSERT(false && "not implemented");