From ad1226982fd52181d1e7cb08183159e559c5328d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 6 Nov 2024 19:02:28 +0200 Subject: [PATCH] metal : do not build bfloat kernels when not supported ggml-ci --- ggml/src/ggml-metal.m | 6 +++++- ggml/src/ggml-metal.metal | 26 ++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index bfe3725d5..6cbe88170 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -496,7 +496,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de // dictionary of preprocessor macros NSMutableDictionary * prep = [NSMutableDictionary dictionary]; - MTLCompileOptions* options = [MTLCompileOptions new]; + if (!ctx_dev->has_bfloat) { + [prep setObject:@"GGML_METAL_NO_BFLOAT" forKey:@"GGML_METAL_NO_BFLOAT"]; + } + + MTLCompileOptions * options = [MTLCompileOptions new]; options.preprocessorMacros = prep; //[options setFastMathEnabled:false]; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2cfaf98b5..accdaaa6d 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -12,12 +12,14 @@ using namespace metal; #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 +#if !defined(GGML_METAL_NO_BFLOAT) +typedef matrix bfloat4x4; +#endif + constexpr constant static float kvalues_iq4nl_f[16] = { -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f }; -typedef matrix bfloat4x4; - // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { @@ -29,10 +31,12 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) reg = (type4x4)(*src); } +#if !defined(GGML_METAL_NO_BFLOAT) template void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { reg = (type4x4)(*src); } +#endif template void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { @@ -2048,8 +2052,10 @@ typedef decltype(kernel_mul_mv) mul_mv_t; template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv; template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv; template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv; +#if !defined(GGML_METAL_NO_BFLOAT) template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t kernel_mul_mv; template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv; +#endif template kernel void kernel_mul_mv_1row( @@ -2119,7 +2125,9 @@ kernel void kernel_mul_mv_1row( typedef decltype(kernel_mul_mv_1row) mul_mv_1row_t; template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; +#if !defined(GGML_METAL_NO_BFLOAT) template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row; +#endif // Assumes row size (ne00) is a multiple of 4 template @@ -2179,7 +2187,9 @@ kernel void kernel_mul_mv_l4( typedef decltype(kernel_mul_mv_l4) mul_mv_l4_t; template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; +#if !defined(GGML_METAL_NO_BFLOAT) template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4; +#endif static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / max(0.001f, high - low); @@ -3578,11 +3588,15 @@ typedef decltype(kernel_cpy) kernel_cpy_t; template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; +#if !defined(GGML_METAL_NO_BFLOAT) template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; +#endif template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; +#if !defined(GGML_METAL_NO_BFLOAT) template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy; +#endif kernel void kernel_cpy_f32_q8_0( device const float * src0, @@ -6487,7 +6501,9 @@ typedef decltype(kernel_get_rows_f) get_rows_f_t; template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +#if !defined(GGML_METAL_NO_BFLOAT) template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; +#endif typedef decltype(kernel_get_rows_q) get_rows_q_t; @@ -6519,7 +6535,9 @@ typedef decltype(kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +#if !defined(GGML_METAL_NO_BFLOAT) template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm; +#endif template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; @@ -6548,7 +6566,9 @@ typedef decltype(kernel_mul_mm_id) mat_mm_id_t; template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#if !defined(GGML_METAL_NO_BFLOAT) template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +#endif template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; @@ -6772,7 +6792,9 @@ typedef decltype(kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +#if !defined(GGML_METAL_NO_BFLOAT) template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +#endif template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;