Patch: move the MSL kernels from the inline string in examples/mtl/mtl.m into
examples/mtl/mtl.metal, which mtl.m now reads and compiles at runtime. CMake additionally
builds mtl.metallib from the same file.

Review notes:
  * Line breaks and indentation were rebuilt from a whitespace-collapsed copy of this diff.
    Context and removed lines may not match the target tree byte-for-byte, so apply with
    `git apply --ignore-whitespace` and check the result.
  * The find_library() removed/added pairs differed only in column alignment; the alignment
    shown here is a reconstruction.
  * `#include <metal_stdlib>` was restored; the header name was missing from the copy.
  * Added lines differ from the copy in these places (hunk sizes were updated to match):
      - mtl.m: the active loader checks the returned path, string and library instead of
        the NSError pointer.
      - mtl.metal: kernel_soft_max seeds its running maximum with -INFINITY instead of 0.0f,
        and each kernel has a short header comment.
    The post-image hashes on the `index` lines therefore no longer describe the result.

diff --git a/examples/mtl/CMakeLists.txt b/examples/mtl/CMakeLists.txt
index c532a5582..1de83a1b6 100644
--- a/examples/mtl/CMakeLists.txt
+++ b/examples/mtl/CMakeLists.txt
@@ -2,9 +2,9 @@ if (APPLE)
     #
     # mtl
 
-    find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
-    find_library(METAL_FRAMEWORK Metal REQUIRED)
-    find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
+    find_library(FOUNDATION_LIBRARY         Foundation              REQUIRED)
+    find_library(METAL_FRAMEWORK            Metal                   REQUIRED)
+    find_library(METALKIT_FRAMEWORK         MetalKit                REQUIRED)
     find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
 
     set(TEST_TARGET mtl)
@@ -16,5 +16,18 @@ if (APPLE)
         ${METALKIT_FRAMEWORK}
         ${METALPERFORMANCE_FRAMEWORK}
     )
+
+    # TODO: temporary until the kernels are ready
+    # custom command to build mtl.metal into a library
+    # depends on the mtl.metal file
+    add_custom_target(mtl.metallib-tmp ALL DEPENDS ${CMAKE_BINARY_DIR}/mtl.metallib)
+
+    add_custom_command(
+        OUTPUT ${CMAKE_BINARY_DIR}/mtl.metallib
+        COMMAND xcrun -sdk macosx metal -c ${CMAKE_CURRENT_SOURCE_DIR}/mtl.metal -o ${CMAKE_BINARY_DIR}/mtl.air
+        COMMAND xcrun -sdk macosx metallib ${CMAKE_BINARY_DIR}/mtl.air -o ${CMAKE_BINARY_DIR}/mtl.metallib
+        DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/mtl.metal
+        COMMENT "Building mtl.metallib"
+    )
 endif()
 
diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index 47bbdb4ad..86e0b0c78 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -32,47 +32,9 @@ struct ggml_mtl_context {
 };
 
 // MSL code
-NSString * const msl_library_llama = @"\
-#include <metal_stdlib> \n\
-using namespace metal; \n\
- \n\
-#define MAX(x, y) ((x) > (y) ? (x) : (y)) \n\
- \n\
-constant int k_digits [[function_constant(0)]]; \n\
- \n\
-kernel void kernel_add( \n\
-        device const float * src0, \n\
-        device const float * src1, \n\
-        device float * dst, \n\
-        uint gid[[thread_position_in_grid]]) { \n\
-    dst[gid] = src0[gid] + src1[gid]; \n\
-} \n\
- \n\
-kernel void kernel_relu( \n\
-        device const float * src, \n\
-        device float * dst, \n\
-        uint gid[[thread_position_in_grid]]) { \n\
-    dst[gid] = max(0.0f, src[gid]); \n\
-} \n\
- \n\
-kernel void kernel_soft_max( \n\
-        device const float * src, \n\
-        device float * dst, \n\
-        uint gid[[thread_position_in_grid]]) { \n\
-    float max = 0.0f; \n\
-    for (int i = 0; i < k_digits; i++) { \n\
-        max = MAX(max, src[i]); \n\
-    } \n\
-    float sum = 0.0f; \n\
-    for (int i = 0; i < k_digits; i++) { \n\
-        dst[i] = exp(src[i] - max); \n\
-        sum += dst[i]; \n\
-    } \n\
-    for (int i = 0; i < k_digits; i++) { \n\
-        dst[i] /= sum; \n\
-    } \n\
-} \n\
-";
+// TODO: move the contents here when ready
+// for now it is easier to work in a separate file
+NSString * const msl_library_llama = @"see mtl.metal";
 
 struct ggml_mtl_context * llama_mtl_init(
         struct ggml_context * ctx_data,
@@ -98,15 +60,58 @@ struct ggml_mtl_context * llama_mtl_init(
         GGML_ASSERT(false && "MPS not supported");
     }
 
+#if 0
     // compile from source string and show compile log
     {
         NSError * error = nil;
+
         ctx->library = [ctx->device newLibraryWithSource:msl_library_llama options:nil error:&error];
         if (error) {
             fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
             exit(1);
         }
     }
+#elif 0
+    // this does not work !?!?!
+
+    // load library from "mtl.metallib"
+    {
+        NSError * error = nil;
+
+        NSString * path = [[NSBundle mainBundle] pathForResource:@"./mtl" ofType:@"metallib"];
+        ctx->library = [ctx->device newLibraryWithFile:path error:&error];
+        if (error) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+    }
+#else
+    // read the source from "../examples/mtl/mtl.metal" into a string and use newLibraryWithSource
+    {
+        NSError * error = nil;
+
+        NSString * path = [[NSBundle mainBundle] pathForResource:@"../examples/mtl/mtl" ofType:@"metal"];
+        if (path == nil) {
+            // a failed lookup returns nil and produces no NSError, so report it explicitly
+            fprintf(stderr, "%s: error: could not find '../examples/mtl/mtl.metal'\n", __func__);
+            exit(1);
+        }
+
+        NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
+        if (src == nil) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+
+        // judge success by the returned library: Metal also fills in the NSError when the
+        // source compiles with warnings only
+        ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+        if (ctx->library == nil) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+    }
+#endif
 
     // load kernels
     {
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
new file mode 100644
index 000000000..e9597336c
--- /dev/null
+++ b/examples/mtl/mtl.metal
@@ -0,0 +1,46 @@
+#include <metal_stdlib>
+
+using namespace metal;
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+
+constant int k_digits [[function_constant(0)]];
+
+// element-wise add: each thread handles the element at its grid position
+kernel void kernel_add(
+        device const float * src0,
+        device const float * src1,
+        device float * dst,
+        uint gid[[thread_position_in_grid]]) {
+    dst[gid] = src0[gid] + src1[gid];
+}
+
+// element-wise max(0, x): each thread handles the element at its grid position
+kernel void kernel_relu(
+        device const float * src,
+        device float * dst,
+        uint gid[[thread_position_in_grid]]) {
+    dst[gid] = max(0.0f, src[gid]);
+}
+
+// softmax over the first k_digits elements of src, written to dst
+// note: gid is not used, so every thread that runs this repeats the same work
+kernel void kernel_soft_max(
+        device const float * src,
+        device float * dst,
+        uint gid[[thread_position_in_grid]]) {
+    // seed with -INFINITY, not 0.0f: with a 0.0f seed and all inputs very negative, every
+    // exp() underflows to 0 and the normalization below becomes 0/0
+    float max = -INFINITY;
+    for (int i = 0; i < k_digits; i++) {
+        max = MAX(max, src[i]);
+    }
+    float sum = 0.0f;
+    for (int i = 0; i < k_digits; i++) {
+        dst[i] = exp(src[i] - max);
+        sum += dst[i];
+    }
+    for (int i = 0; i < k_digits; i++) {
+        dst[i] /= sum;
+    }
+}