mtl : move MSL code into separate file for easy editing

This commit is contained in:
Georgi Gerganov 2023-05-29 22:26:40 +03:00
parent 897d6d8e8f
commit 248a8c3379
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 94 additions and 44 deletions

View file

@ -16,5 +16,18 @@ if (APPLE)
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK}
)
# TODO: temporary until the kernels are ready
# custom command to build mtl.metal into a library
# depends on the mtl.metal file
add_custom_target(mtl.metallib-tmp ALL DEPENDS ${CMAKE_BINARY_DIR}/mtl.metallib)
add_custom_command(
OUTPUT ${CMAKE_BINARY_DIR}/mtl.metallib
COMMAND xcrun -sdk macosx metal -c ${CMAKE_CURRENT_SOURCE_DIR}/mtl.metal -o ${CMAKE_BINARY_DIR}/mtl.air
COMMAND xcrun -sdk macosx metallib ${CMAKE_BINARY_DIR}/mtl.air -o ${CMAKE_BINARY_DIR}/mtl.metallib
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/mtl.metal
COMMENT "Building mtl.metallib"
)
endif()

View file

@ -32,47 +32,9 @@ struct ggml_mtl_context {
};
// MSL code
NSString * const msl_library_llama = @"\
#include <metal_stdlib> \n\
using namespace metal; \n\
\n\
#define MAX(x, y) ((x) > (y) ? (x) : (y)) \n\
\n\
constant int k_digits [[function_constant(0)]]; \n\
\n\
kernel void kernel_add( \n\
device const float * src0, \n\
device const float * src1, \n\
device float * dst, \n\
uint gid[[thread_position_in_grid]]) { \n\
dst[gid] = src0[gid] + src1[gid]; \n\
} \n\
\n\
kernel void kernel_relu( \n\
device const float * src, \n\
device float * dst, \n\
uint gid[[thread_position_in_grid]]) { \n\
dst[gid] = max(0.0f, src[gid]); \n\
} \n\
\n\
kernel void kernel_soft_max( \n\
device const float * src, \n\
device float * dst, \n\
uint gid[[thread_position_in_grid]]) { \n\
float max = 0.0f; \n\
for (int i = 0; i < k_digits; i++) { \n\
max = MAX(max, src[i]); \n\
} \n\
float sum = 0.0f; \n\
for (int i = 0; i < k_digits; i++) { \n\
dst[i] = exp(src[i] - max); \n\
sum += dst[i]; \n\
} \n\
for (int i = 0; i < k_digits; i++) { \n\
dst[i] /= sum; \n\
} \n\
} \n\
";
// TODO: move the contents here when ready
// for now it is easier to work in a separate file
NSString * const msl_library_llama = @"see mtl.metal";
struct ggml_mtl_context * llama_mtl_init(
struct ggml_context * ctx_data,
@ -98,15 +60,50 @@ struct ggml_mtl_context * llama_mtl_init(
GGML_ASSERT(false && "MPS not supported");
}
#if 0
// compile from source string and show compile log
{
NSError * error = nil;
ctx->library = [ctx->device newLibraryWithSource:msl_library_llama options:nil error:&error];
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
exit(1);
}
}
#elif 0
// this does not work !?!?!
// load library from "mtl.metallib"
{
NSError * error = nil;
NSString * path = [[NSBundle mainBundle] pathForResource:@"./mtl" ofType:@"metallib"];
ctx->library = [ctx->device newLibraryWithFile:path error:&error];
if (error) {
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
exit(1);
}
}
#else
// read the source from "../examples/mtl/mtl.metal" into a string and use newLibraryWithSource
{
    // Read the MSL source from "../examples/mtl/mtl.metal" (resolved relative to
    // the main bundle) and compile it at runtime with newLibraryWithSource.
    //
    // Success of each Cocoa call is judged by its return value, not by the
    // NSError out-parameter: the error pointer is only meaningful on failure,
    // and the Metal compiler may populate it with warnings even when a valid
    // library is returned.
    NSError * error = nil;

    NSString * path = [[NSBundle mainBundle] pathForResource:@"../examples/mtl/mtl" ofType:@"metal"];
    if (path == nil) {
        // pathForResource returns nil when the file cannot be located
        fprintf(stderr, "%s: error: could not locate '../examples/mtl/mtl.metal'\n", __func__);
        exit(1);
    }

    NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
    if (src == nil) {
        fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
        exit(1);
    }

    // start from a clean error so that anything reported below comes from the compiler
    error = nil;

    ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
    if (ctx->library == nil) {
        fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
        exit(1);
    }

    if (error != nil) {
        // library was created but the compiler still reported something - treat as warnings
        fprintf(stderr, "%s: warning: %s\n", __func__, [[error description] UTF8String]);
    }
}
#endif
// load kernels
{

40
examples/mtl/mtl.metal Normal file
View file

@ -0,0 +1,40 @@
#include <metal_stdlib>
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
constant int k_digits [[function_constant(0)]];
// Element-wise addition of two float buffers.
//
// Each thread handles exactly one element, selected by its position in the
// grid: dst[gid] = src0[gid] + src1[gid].
kernel void kernel_add(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        uint gid[[thread_position_in_grid]]) {
    const float lhs = src0[gid];
    const float rhs = src1[gid];

    dst[gid] = lhs + rhs;
}
// Element-wise ReLU over a float buffer.
//
// Each thread handles exactly one element, selected by its position in the
// grid: dst[gid] = max(0, src[gid]).
kernel void kernel_relu(
        device const float * src,
        device       float * dst,
        uint gid[[thread_position_in_grid]]) {
    const float x = src[gid];

    dst[gid] = max(0.0f, x);
}
// Softmax over the first k_digits elements of src, written to dst.
//
// The computation is done in three passes:
//   1. find the largest input value
//   2. store exp(src[i] - largest) in dst[i] and accumulate the sum
//   3. divide every dst[i] by the sum
//
// Subtracting the largest value keeps every exponent <= 0, so exp() cannot
// overflow and at least one term equals exp(0) = 1, which keeps the sum
// non-zero for finite inputs.
//
// NOTE(review): gid is not used - every thread in the grid runs the full
// computation over the same k_digits elements. With more than one thread the
// read-modify-write in pass 3 (and the read-back of dst in pass 2) can race;
// presumably this kernel is dispatched with a single thread - confirm against
// the host-side dispatch code.
kernel void kernel_soft_max(
        device const float * src,
        device       float * dst,
        uint gid[[thread_position_in_grid]]) {
    // Seed with -INFINITY rather than 0: with a 0 seed, all-negative inputs
    // would leave the shift at 0 and very negative values could underflow to
    // exp() == 0 for every element, producing a 0/0 division below.
    // (also named so that it does not shadow metal::max)
    float max_val = -INFINITY;
    for (int i = 0; i < k_digits; i++) {
        max_val = MAX(max_val, src[i]);
    }

    float sum = 0.0f;
    for (int i = 0; i < k_digits; i++) {
        dst[i] = exp(src[i] - max_val);
        sum += dst[i];
    }

    for (int i = 0; i < k_digits; i++) {
        dst[i] /= sum;
    }
}