mtl : move MSL code into separate file for easy editing
This commit is contained in:
parent
897d6d8e8f
commit
248a8c3379
3 changed files with 94 additions and 44 deletions
|
@ -2,9 +2,9 @@ if (APPLE)
|
|||
#
|
||||
# mtl
|
||||
|
||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
|
||||
|
||||
set(TEST_TARGET mtl)
|
||||
|
@ -16,5 +16,18 @@ if (APPLE)
|
|||
${METALKIT_FRAMEWORK}
|
||||
${METALPERFORMANCE_FRAMEWORK}
|
||||
)
|
||||
|
||||
# TODO: temporary until the kernels are ready
#
# Pre-compile mtl.metal into a Metal library in two steps:
#   mtl.metal --(metal -c)--> mtl.air --(metallib)--> mtl.metallib
# The custom target is part of ALL, so the library is (re)built by a default
# build whenever mtl.metal changes.
add_custom_target(mtl.metallib-tmp ALL DEPENDS ${CMAKE_BINARY_DIR}/mtl.metallib)

add_custom_command(
    OUTPUT ${CMAKE_BINARY_DIR}/mtl.metallib
    COMMAND xcrun -sdk macosx metal -c ${CMAKE_CURRENT_SOURCE_DIR}/mtl.metal -o ${CMAKE_BINARY_DIR}/mtl.air
    COMMAND xcrun -sdk macosx metallib ${CMAKE_BINARY_DIR}/mtl.air -o ${CMAKE_BINARY_DIR}/mtl.metallib
    DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/mtl.metal
    # mtl.air is an intermediate file: declare it so `clean` removes it and
    # generators that track every produced file know where it comes from
    BYPRODUCTS ${CMAKE_BINARY_DIR}/mtl.air
    COMMENT "Building mtl.metallib"
    # pass the arguments to xcrun exactly as written, independent of the
    # platform-specific escaping rules of the build tool
    VERBATIM
)
|
||||
endif()
|
||||
|
||||
|
|
|
@ -32,47 +32,9 @@ struct ggml_mtl_context {
|
|||
};
|
||||
|
||||
// MSL code
|
||||
NSString * const msl_library_llama = @"\
|
||||
#include <metal_stdlib> \n\
|
||||
using namespace metal; \n\
|
||||
\n\
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y)) \n\
|
||||
\n\
|
||||
constant int k_digits [[function_constant(0)]]; \n\
|
||||
\n\
|
||||
kernel void kernel_add( \n\
|
||||
device const float * src0, \n\
|
||||
device const float * src1, \n\
|
||||
device float * dst, \n\
|
||||
uint gid[[thread_position_in_grid]]) { \n\
|
||||
dst[gid] = src0[gid] + src1[gid]; \n\
|
||||
} \n\
|
||||
\n\
|
||||
kernel void kernel_relu( \n\
|
||||
device const float * src, \n\
|
||||
device float * dst, \n\
|
||||
uint gid[[thread_position_in_grid]]) { \n\
|
||||
dst[gid] = max(0.0f, src[gid]); \n\
|
||||
} \n\
|
||||
\n\
|
||||
kernel void kernel_soft_max( \n\
|
||||
device const float * src, \n\
|
||||
device float * dst, \n\
|
||||
uint gid[[thread_position_in_grid]]) { \n\
|
||||
float max = 0.0f; \n\
|
||||
for (int i = 0; i < k_digits; i++) { \n\
|
||||
max = MAX(max, src[i]); \n\
|
||||
} \n\
|
||||
float sum = 0.0f; \n\
|
||||
for (int i = 0; i < k_digits; i++) { \n\
|
||||
dst[i] = exp(src[i] - max); \n\
|
||||
sum += dst[i]; \n\
|
||||
} \n\
|
||||
for (int i = 0; i < k_digits; i++) { \n\
|
||||
dst[i] /= sum; \n\
|
||||
} \n\
|
||||
} \n\
|
||||
";
|
||||
// TODO: move the contents here when ready
|
||||
// for now it is easier to work in a separate file
|
||||
// Placeholder value only: this string is not valid MSL source. In this file it
// is referenced solely by the disabled (#if 0) newLibraryWithSource: branch.
NSString * const msl_library_llama = @"see mtl.metal";
|
||||
|
||||
struct ggml_mtl_context * llama_mtl_init(
|
||||
struct ggml_context * ctx_data,
|
||||
|
@ -98,15 +60,50 @@ struct ggml_mtl_context * llama_mtl_init(
|
|||
GGML_ASSERT(false && "MPS not supported");
|
||||
}
|
||||
|
||||
// Create ctx->library. Three alternatives are kept side by side while the
// kernels are under development; only the last one is currently compiled in.
//
// In every branch success is decided by the returned object, not by `error`:
// Metal may fill `error` with compiler warnings even when a library is returned.
#if 0
// Option 1: compile from the embedded source string and show the compile log
{
    NSError * error = nil;

    ctx->library = [ctx->device newLibraryWithSource:msl_library_llama options:nil error:&error];
    if (ctx->library == nil) {
        fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
        exit(1);
    }
}
#elif 0
// Option 2: load the pre-compiled library from "mtl.metallib"
//
// NOTE: this path does not work yet.
// NOTE(review): presumably the lookup through the main bundle does not find
// the file produced by the build - TODO confirm.
{
    NSError * error = nil;

    NSString * path = [[NSBundle mainBundle] pathForResource:@"./mtl" ofType:@"metallib"];
    if (path == nil) {
        fprintf(stderr, "%s: error: could not find 'mtl.metallib' in the main bundle\n", __func__);
        exit(1);
    }

    ctx->library = [ctx->device newLibraryWithFile:path error:&error];
    if (ctx->library == nil) {
        fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
        exit(1);
    }
}
#else
// Option 3: read the source from "../examples/mtl/mtl.metal" into a string and
// compile it at runtime with newLibraryWithSource
{
    NSError * error = nil;

    // pathForResource: returns nil when the file cannot be located; fail with
    // a clear message instead of passing nil further down
    NSString * path = [[NSBundle mainBundle] pathForResource:@"../examples/mtl/mtl" ofType:@"metal"];
    if (path == nil) {
        fprintf(stderr, "%s: error: could not find '../examples/mtl/mtl.metal' relative to the main bundle\n", __func__);
        exit(1);
    }

    NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
    if (src == nil) {
        fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
        exit(1);
    }

    // start from a clean error object for the compile step
    error = nil;

    ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
    if (ctx->library == nil) {
        fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
        exit(1);
    }
}
#endif
|
||||
|
||||
// load kernels
|
||||
{
|
||||
|
|
40
examples/mtl/mtl.metal
Normal file
40
examples/mtl/mtl.metal
Normal file
|
@ -0,0 +1,40 @@
|
|||
#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||
|
||||
constant int k_digits [[function_constant(0)]];
|
||||
|
||||
// Element-wise sum of two float buffers.
// Each thread handles exactly one element, selected by its position in the grid:
//   dst[gid] = src0[gid] + src1[gid]
kernel void kernel_add(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        uint gid[[thread_position_in_grid]]) {
    const float lhs = src0[gid];
    const float rhs = src1[gid];

    dst[gid] = lhs + rhs;
}
|
||||
|
||||
// Rectified linear unit applied element-wise.
// Each thread clamps the single element at its grid position to be >= 0:
//   dst[gid] = max(0, src[gid])
kernel void kernel_relu(
        device const float * src,
        device       float * dst,
        uint gid[[thread_position_in_grid]]) {
    const float x = src[gid];

    dst[gid] = max(0.0f, x);
}
|
||||
|
||||
// Softmax over the first k_digits elements of src, written to dst:
//   dst[i] = exp(src[i] - m) / sum_j exp(src[j] - m),   m = max_j src[j]
//
// The whole reduction is serial and touches every element of dst, so it must be
// executed by a single thread. Any extra threads of the dispatch return
// immediately; otherwise they would race on the read-modify-write of dst.
kernel void kernel_soft_max(
        device const float * src,
        device       float * dst,
        uint gid[[thread_position_in_grid]]) {
    if (gid != 0) {
        return;
    }

    // Start from -INFINITY so the shift is the true maximum even when every
    // input is negative. Starting from 0 lets exp() underflow to 0 for very
    // negative inputs, which makes the sum 0 and the result NaN.
    float max_val = -INFINITY;
    for (int i = 0; i < k_digits; i++) {
        max_val = MAX(max_val, src[i]);
    }

    // exponentiate the shifted values and accumulate the normalization term
    float sum = 0.0f;
    for (int i = 0; i < k_digits; i++) {
        const float e = exp(src[i] - max_val);

        dst[i] = e;
        sum   += e;
    }

    // normalize; the largest element contributes exp(0) = 1, so sum >= 1
    for (int i = 0; i < k_digits; i++) {
        dst[i] /= sum;
    }
}
|
Loading…
Add table
Add a link
Reference in a new issue