From a112eb45c4584328c4a47f00e3369ae309147b64 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 18:29:09 +0200 Subject: [PATCH] ggml : add ggml-metal-impl.h ggml-ci --- Makefile | 5 +- ggml/src/ggml-common.h | 240 ------------------------- ggml/src/ggml-metal/CMakeLists.txt | 18 +- ggml/src/ggml-metal/ggml-metal-impl.h | 249 ++++++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal.m | 5 +- ggml/src/ggml-metal/ggml-metal.metal | 2 +- 6 files changed, 266 insertions(+), 253 deletions(-) create mode 100644 ggml/src/ggml-metal/ggml-metal-impl.h diff --git a/Makefile b/Makefile index fecf1f693..647da232b 100644 --- a/Makefile +++ b/Makefile @@ -963,6 +963,7 @@ endif # GGML_METAL ifdef GGML_METAL ggml/src/ggml-metal/ggml-metal.o: \ ggml/src/ggml-metal/ggml-metal.m \ + ggml/src/ggml-metal/ggml-metal-impl.h \ ggml/include/ggml-metal.h \ ggml/include/ggml.h $(CC) $(CFLAGS) -c $< -o $@ @@ -970,9 +971,11 @@ ggml/src/ggml-metal/ggml-metal.o: \ ifdef GGML_METAL_EMBED_LIBRARY ggml/src/ggml-metal-embed.o: \ ggml/src/ggml-metal/ggml-metal.metal \ + ggml/src/ggml-metal/ggml-metal-impl.h \ ggml/src/ggml-common.h @echo "Embedding Metal library" - @sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal + @sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp + @sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal $(eval TEMP_ASSEMBLY=$(shell mktemp -d)) @echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s @echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 
d25100693..050161393 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -418,246 +418,6 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); -#if defined(GGML_COMMON_DECL_METAL_KARGS) -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne10; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; - int32_t dim; -} ggml_metal_kargs_concat; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne10; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; - uint64_t offs; -} ggml_metal_kargs_bin; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; -} ggml_metal_kargs_repeat; - -typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; - int64_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int64_t ne0; - int64_t ne1; - int64_t ne2; - int64_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; -} ggml_metal_kargs_cpy; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne0; - int32_t 
ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; - int32_t n_past; - int32_t n_dims; - int32_t n_ctx_orig; - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; -} ggml_metal_kargs_rope; - -typedef struct { - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne11; - int32_t ne_12_2; // assume K and V are same shape - int32_t ne_12_3; - uint64_t nb_12_1; - uint64_t nb_12_2; - uint64_t nb_12_3; - uint64_t nb31; - int32_t ne1; - int32_t ne2; - float scale; - float max_bias; - float m0; - float m1; - uint16_t n_head_log2; - float logit_softcap; -} ggml_metal_kargs_flash_attn_ext; - -typedef struct { - int32_t ne00; - int32_t ne02; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne12; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int16_t r2; - int16_t r3; -} ggml_metal_kargs_mul_mm; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne10; - int32_t ne11; - int32_t ne12; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int16_t r2; - int16_t r3; -} ggml_metal_kargs_mul_mv; - -typedef struct { - int32_t nei0; - int32_t nei1; - uint64_t nbi1; - int32_t ne00; - int32_t ne02; - uint64_t nb01; - uint64_t nb02; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - int32_t ne0; - int32_t ne1; -} ggml_metal_kargs_mul_mm_id; - -typedef struct { - int32_t nei0; - int32_t nei1; - uint64_t nbi1; - int32_t ne00; - int32_t ne01; - int32_t ne02; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - int32_t ne10; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - int32_t ne0; - int32_t ne1; - uint64_t nb1; -} 
ggml_metal_kargs_mul_mv_id; - -typedef struct { - int32_t ne00; - int32_t ne00_4; - uint64_t nb01; - float eps; -} ggml_metal_kargs_norm; - -typedef struct { - int32_t ne00; - int32_t ne00_4; - uint64_t nb01; - float eps; -} ggml_metal_kargs_rms_norm; -#endif - #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index e0992c744..b237d79f4 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -25,9 +25,10 @@ if (GGML_METAL_USE_BF16) add_compile_definitions(GGML_METAL_USE_BF16) endif() -# copy ggml-common.h and ggml-metal.metal to bin directory -configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) -configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) +# copy metal files to bin directory +configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) +configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) +configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) if (GGML_METAL_EMBED_LIBRARY) enable_language(ASM) @@ -36,24 +37,27 @@ if (GGML_METAL_EMBED_LIBRARY) set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h") set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") + set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") # merge ggml-common.h and ggml-metal.metal into a single file - set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") - set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") + set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_SOURCE_EMBED_TMP
"${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") add_custom_command( OUTPUT ${METALLIB_EMBED_ASM} COMMAND echo "Embedding Metal library" - COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED} + COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED_TMP} + COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}' -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED} COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM} COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM} COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM} COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM} COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM} COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM} - DEPENDS ggml-metal.metal ../ggml-common.h + DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h COMMENT "Generate assembly for embedded Metal library" ) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h new file mode 100644 index 000000000..53c135496 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -0,0 +1,249 @@ +#ifndef GGML_METAL_IMPL +#define GGML_METAL_IMPL + +// kernel argument structs +// +// - element counters (e.g. ne00) typically use int32_t to reduce register usage +// however, beware of int overflows when using those in the kernel implementation +// +// - strides (e.g.
nb00) use uint64_t + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t dim; +} ggml_metal_kargs_concat; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; +} ggml_metal_kargs_bin; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_repeat; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_cpy; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t n_past; + int32_t n_dims; + int32_t n_ctx_orig; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; +} 
ggml_metal_kargs_rope; + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb_12_1; + uint64_t nb_12_2; + uint64_t nb_12_3; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext; + +typedef struct { + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mm; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mv; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; +} ggml_metal_kargs_mul_mm_id; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; + uint64_t nb1; +} ggml_metal_kargs_mul_mv_id; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_norm; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_rms_norm; + +#endif // GGML_METAL_IMPL diff 
--git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index b683d5431..58fee4bfd 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2,10 +2,7 @@ #import "ggml-impl.h" #import "ggml-backend-impl.h" - -#define GGML_COMMON_DECL_C -#define GGML_COMMON_DECL_METAL_KARGS -#include "ggml-common.h" +#import "ggml-metal-impl.h" #import diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 6bdf4e4cc..86fdf1c18 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1,5 +1,4 @@ #define GGML_COMMON_DECL_METAL -#define GGML_COMMON_DECL_METAL_KARGS #define GGML_COMMON_IMPL_METAL #if defined(GGML_METAL_EMBED_LIBRARY) __embed_ggml-common.h__ @@ -7,6 +6,7 @@ __embed_ggml-common.h__ // TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift #include "../ggml-common.h" #endif +#include "ggml-metal-impl.h" #include