From 8b98d01e31daf2ad32671ba9446d858e5c4769de Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Jun 2023 14:16:12 +0300 Subject: [PATCH] k_quants: call them _K, not _k, also on Metal --- ggml-metal.m | 60 +++++++++++++++++++-------------------- ggml-metal.metal | 74 ++++++++++++++++++++++++------------------------ 2 files changed, 67 insertions(+), 67 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index f84a2c433..7551231b9 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -51,21 +51,21 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); - GGML_METAL_DECL_KERNEL(get_rows_q2_k); - GGML_METAL_DECL_KERNEL(get_rows_q3_k); - GGML_METAL_DECL_KERNEL(get_rows_q4_k); - GGML_METAL_DECL_KERNEL(get_rows_q5_k); - GGML_METAL_DECL_KERNEL(get_rows_q6_k); + GGML_METAL_DECL_KERNEL(get_rows_q2_K); + GGML_METAL_DECL_KERNEL(get_rows_q3_K); + GGML_METAL_DECL_KERNEL(get_rows_q4_K); + GGML_METAL_DECL_KERNEL(get_rows_q5_K); + GGML_METAL_DECL_KERNEL(get_rows_q6_K); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(cpy_f32_f16); @@ -165,21 +165,21 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); - GGML_METAL_ADD_KERNEL(get_rows_q2_k); - GGML_METAL_ADD_KERNEL(get_rows_q3_k); - GGML_METAL_ADD_KERNEL(get_rows_q4_k); - GGML_METAL_ADD_KERNEL(get_rows_q5_k); - GGML_METAL_ADD_KERNEL(get_rows_q6_k); + GGML_METAL_ADD_KERNEL(get_rows_q2_K); + GGML_METAL_ADD_KERNEL(get_rows_q3_K); + GGML_METAL_ADD_KERNEL(get_rows_q4_K); + GGML_METAL_ADD_KERNEL(get_rows_q5_K); + GGML_METAL_ADD_KERNEL(get_rows_q6_K); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f16); @@ -668,7 +668,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; } break; case GGML_TYPE_Q3_K: { @@ -677,7 +677,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; } break; case GGML_TYPE_Q4_K: { @@ -686,7 +686,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; } break; case GGML_TYPE_Q5_K: { @@ -695,7 +695,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32]; } break; case GGML_TYPE_Q6_K: { @@ -704,7 +704,7 @@ void ggml_metal_graph_compute( nth0 = 4; nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32]; } break; default: { @@ -756,11 +756,11 @@ void ggml_metal_graph_compute( case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break; - case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; - case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; + case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index 42c3a0412..3b4eac2bf 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -792,7 +792,7 @@ typedef struct { uint8_t qs[QK_K/4]; // quants half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins -} block_q2_k; +} block_q2_K; // 84 bytes / block typedef struct { @@ -804,20 +804,20 @@ typedef struct { uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits #endif half d; // super-block scale -} block_q3_k; +} block_q3_K; #if QK_K == 64 typedef struct { half4 d; // super-block scales/mins uint8_t qs[QK_K/2]; // 4-bit quants -} block_q4_k; +} block_q4_K; #else typedef struct { half d; // super-block scale for quantized scales half dmin; // super-block scale for quantized mins uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_k; +} block_q4_K; #endif #if QK_K == 64 @@ -825,7 +825,7 @@ typedef struct { half4 d; // super-block scales/mins uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_k; +} block_q5_K; #else typedef struct { half d; // super-block scale for quantized scales @@ -833,7 +833,7 @@ typedef struct { uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits -} block_q5_k; +} block_q5_K; // 176 bytes / block #endif @@ -842,7 +842,7 @@ typedef struct { uint8_t qh[QK_K/4]; // quants, upper 2 bits int8_t scales[QK_K/16]; // scales, quantized with 8 bits half d; // super-block scale -} block_q6_k; +} block_q6_K; // 210 bytes / block static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { @@ -863,7 +863,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { //========================================== dequantization ============================= -static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) { +static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -910,7 +910,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i } } -static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) { +static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -992,7 +992,7 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i } -static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) { +static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1028,7 +1028,7 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i } } -static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) { +static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1077,7 +1077,7 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i } -static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) { +static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) { assert(k % QK_K == 0); const int nb = k / QK_K; @@ -1123,7 +1123,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i } } -kernel void kernel_get_rows_q2_k( +kernel void kernel_get_rows_q2_K( device const void * src0, device const int * src1, device float * dst, @@ -1134,12 +1134,12 @@ kernel void kernel_get_rows_q2_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q2_k( - (device const block_q2_k *) ((device char *) src0 + r*nb01), + dequantize_row_q2_K( + (device const block_q2_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } -kernel void kernel_get_rows_q3_k( +kernel void kernel_get_rows_q3_K( device const void * src0, device const int * src1, device float * dst, @@ -1150,12 +1150,12 @@ kernel void kernel_get_rows_q3_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q3_k( - (device const block_q3_k *) ((device char *) src0 + r*nb01), + dequantize_row_q3_K( + (device const block_q3_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } -kernel void kernel_get_rows_q4_k( +kernel void kernel_get_rows_q4_K( device const void * src0, device const int * src1, device float * dst, @@ -1166,12 +1166,12 @@ kernel void kernel_get_rows_q4_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q4_k( - (device const block_q4_k *) ((device char *) src0 + r*nb01), + dequantize_row_q4_K( + (device const block_q4_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } -kernel void kernel_get_rows_q5_k( +kernel void kernel_get_rows_q5_K( device const void * src0, device const int * src1, device float * dst, @@ -1182,12 +1182,12 @@ kernel void kernel_get_rows_q5_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q5_k( - (device const block_q5_k *) ((device char *) src0 + r*nb01), + dequantize_row_q5_K( + (device const block_q5_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } -kernel void kernel_get_rows_q6_k( +kernel void kernel_get_rows_q6_K( device const void * src0, device const int * src1, device float * dst, @@ -1198,14 +1198,14 @@ kernel void kernel_get_rows_q6_k( const int i = tpig; const int r = ((device int32_t *) src1)[i]; - dequantize_row_q6_k( - (device const block_q6_k *) ((device char *) src0 + r*nb01), + dequantize_row_q6_K( + (device const block_q6_K *) ((device char *) src0 + r*nb01), (device float *) ((device char *) dst + i*nb1), ne00); } //====================================== dot products ========================= -kernel void kernel_mul_mat_q2_k_f32( +kernel void kernel_mul_mat_q2_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1222,7 +1222,7 @@ kernel void kernel_mul_mat_q2_k_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb; + device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; const int nth = tptg.x*tptg.y; @@ -1317,7 +1317,7 @@ kernel void kernel_mul_mat_q2_k_f32( } } -kernel void kernel_mul_mat_q3_k_f32( +kernel void kernel_mul_mat_q3_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1341,7 +1341,7 @@ kernel void kernel_mul_mat_q3_k_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb; + device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; const int nth = tptg.x*tptg.y; @@ -1455,7 +1455,7 @@ kernel void kernel_mul_mat_q3_k_f32( } -kernel void kernel_mul_mat_q4_k_f32( +kernel void kernel_mul_mat_q4_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1475,7 +1475,7 @@ kernel void kernel_mul_mat_q4_k_f32( const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; - device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb; + device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; float sumf = 0; @@ -1580,7 +1580,7 @@ kernel void kernel_mul_mat_q4_k_f32( //} } -kernel void kernel_mul_mat_q5_k_f32( +kernel void kernel_mul_mat_q5_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1597,7 +1597,7 @@ kernel void kernel_mul_mat_q5_k_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb; + device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; const int nth = tptg.x*tptg.y; @@ -1704,7 +1704,7 @@ kernel void kernel_mul_mat_q5_k_f32( } -kernel void kernel_mul_mat_q6_k_f32( +kernel void kernel_mul_mat_q6_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1726,7 +1726,7 @@ kernel void kernel_mul_mat_q6_k_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb; + device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb; device const float * yy = (device const float *) src1 + r1*ne10; const int nth = tptg.x*tptg.y;