mtl : move MSL code into separate file for easy editing
This commit is contained in:
parent
897d6d8e8f
commit
248a8c3379
3 changed files with 94 additions and 44 deletions
|
@ -2,9 +2,9 @@ if (APPLE)
|
||||||
#
|
#
|
||||||
# mtl
|
# mtl
|
||||||
|
|
||||||
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
|
||||||
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
find_library(METAL_FRAMEWORK Metal REQUIRED)
|
||||||
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
|
||||||
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
|
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
|
||||||
|
|
||||||
set(TEST_TARGET mtl)
|
set(TEST_TARGET mtl)
|
||||||
|
@ -16,5 +16,18 @@ if (APPLE)
|
||||||
${METALKIT_FRAMEWORK}
|
${METALKIT_FRAMEWORK}
|
||||||
${METALPERFORMANCE_FRAMEWORK}
|
${METALPERFORMANCE_FRAMEWORK}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: temporary until the kernels are ready
|
||||||
|
# custom command to build mtl.metal into a library
|
||||||
|
# depends on the mtl.metal file
|
||||||
|
add_custom_target(mtl.metallib-tmp ALL DEPENDS ${CMAKE_BINARY_DIR}/mtl.metallib)
|
||||||
|
|
||||||
|
add_custom_command(
|
||||||
|
OUTPUT ${CMAKE_BINARY_DIR}/mtl.metallib
|
||||||
|
COMMAND xcrun -sdk macosx metal -c ${CMAKE_CURRENT_SOURCE_DIR}/mtl.metal -o ${CMAKE_BINARY_DIR}/mtl.air
|
||||||
|
COMMAND xcrun -sdk macosx metallib ${CMAKE_BINARY_DIR}/mtl.air -o ${CMAKE_BINARY_DIR}/mtl.metallib
|
||||||
|
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/mtl.metal
|
||||||
|
COMMENT "Building mtl.metallib"
|
||||||
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
|
@ -32,47 +32,9 @@ struct ggml_mtl_context {
|
||||||
};
|
};
|
||||||
|
|
||||||
// MSL code
|
// MSL code
|
||||||
NSString * const msl_library_llama = @"\
|
// TODO: move the contents here when ready
|
||||||
#include <metal_stdlib> \n\
|
// for now it is easier to work in a separate file
|
||||||
using namespace metal; \n\
|
NSString * const msl_library_llama = @"see mtl.metal";
|
||||||
\n\
|
|
||||||
#define MAX(x, y) ((x) > (y) ? (x) : (y)) \n\
|
|
||||||
\n\
|
|
||||||
constant int k_digits [[function_constant(0)]]; \n\
|
|
||||||
\n\
|
|
||||||
kernel void kernel_add( \n\
|
|
||||||
device const float * src0, \n\
|
|
||||||
device const float * src1, \n\
|
|
||||||
device float * dst, \n\
|
|
||||||
uint gid[[thread_position_in_grid]]) { \n\
|
|
||||||
dst[gid] = src0[gid] + src1[gid]; \n\
|
|
||||||
} \n\
|
|
||||||
\n\
|
|
||||||
kernel void kernel_relu( \n\
|
|
||||||
device const float * src, \n\
|
|
||||||
device float * dst, \n\
|
|
||||||
uint gid[[thread_position_in_grid]]) { \n\
|
|
||||||
dst[gid] = max(0.0f, src[gid]); \n\
|
|
||||||
} \n\
|
|
||||||
\n\
|
|
||||||
kernel void kernel_soft_max( \n\
|
|
||||||
device const float * src, \n\
|
|
||||||
device float * dst, \n\
|
|
||||||
uint gid[[thread_position_in_grid]]) { \n\
|
|
||||||
float max = 0.0f; \n\
|
|
||||||
for (int i = 0; i < k_digits; i++) { \n\
|
|
||||||
max = MAX(max, src[i]); \n\
|
|
||||||
} \n\
|
|
||||||
float sum = 0.0f; \n\
|
|
||||||
for (int i = 0; i < k_digits; i++) { \n\
|
|
||||||
dst[i] = exp(src[i] - max); \n\
|
|
||||||
sum += dst[i]; \n\
|
|
||||||
} \n\
|
|
||||||
for (int i = 0; i < k_digits; i++) { \n\
|
|
||||||
dst[i] /= sum; \n\
|
|
||||||
} \n\
|
|
||||||
} \n\
|
|
||||||
";
|
|
||||||
|
|
||||||
struct ggml_mtl_context * llama_mtl_init(
|
struct ggml_mtl_context * llama_mtl_init(
|
||||||
struct ggml_context * ctx_data,
|
struct ggml_context * ctx_data,
|
||||||
|
@ -98,15 +60,50 @@ struct ggml_mtl_context * llama_mtl_init(
|
||||||
GGML_ASSERT(false && "MPS not supported");
|
GGML_ASSERT(false && "MPS not supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if 0
|
||||||
// compile from source string and show compile log
|
// compile from source string and show compile log
|
||||||
{
|
{
|
||||||
NSError * error = nil;
|
NSError * error = nil;
|
||||||
|
|
||||||
ctx->library = [ctx->device newLibraryWithSource:msl_library_llama options:nil error:&error];
|
ctx->library = [ctx->device newLibraryWithSource:msl_library_llama options:nil error:&error];
|
||||||
if (error) {
|
if (error) {
|
||||||
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#elif 0
|
||||||
|
// this does not work !?!?!
|
||||||
|
|
||||||
|
// load library from "mtl.metallib"
|
||||||
|
{
|
||||||
|
NSError * error = nil;
|
||||||
|
|
||||||
|
NSString * path = [[NSBundle mainBundle] pathForResource:@"./mtl" ofType:@"metallib"];
|
||||||
|
ctx->library = [ctx->device newLibraryWithFile:path error:&error];
|
||||||
|
if (error) {
|
||||||
|
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
// read the source from "../examples/mtl/mtl.metal" into a string and use newLibraryWithSource
|
||||||
|
{
|
||||||
|
NSError * error = nil;
|
||||||
|
|
||||||
|
NSString * path = [[NSBundle mainBundle] pathForResource:@"../examples/mtl/mtl" ofType:@"metal"];
|
||||||
|
NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
|
||||||
|
if (error) {
|
||||||
|
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
|
||||||
|
if (error) {
|
||||||
|
fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// load kernels
|
// load kernels
|
||||||
{
|
{
|
||||||
|
|
40
examples/mtl/mtl.metal
Normal file
40
examples/mtl/mtl.metal
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
#include <metal_stdlib>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
#define MAX(x, y) ((x) > (y) ? (x) : (y))
|
||||||
|
|
||||||
|
constant int k_digits [[function_constant(0)]];
|
||||||
|
|
||||||
|
kernel void kernel_add(
|
||||||
|
device const float * src0,
|
||||||
|
device const float * src1,
|
||||||
|
device float * dst,
|
||||||
|
uint gid[[thread_position_in_grid]]) {
|
||||||
|
dst[gid] = src0[gid] + src1[gid];
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void kernel_relu(
|
||||||
|
device const float * src,
|
||||||
|
device float * dst,
|
||||||
|
uint gid[[thread_position_in_grid]]) {
|
||||||
|
dst[gid] = max(0.0f, src[gid]);
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void kernel_soft_max(
|
||||||
|
device const float * src,
|
||||||
|
device float * dst,
|
||||||
|
uint gid[[thread_position_in_grid]]) {
|
||||||
|
float max = 0.0f;
|
||||||
|
for (int i = 0; i < k_digits; i++) {
|
||||||
|
max = MAX(max, src[i]);
|
||||||
|
}
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (int i = 0; i < k_digits; i++) {
|
||||||
|
dst[i] = exp(src[i] - max);
|
||||||
|
sum += dst[i];
|
||||||
|
}
|
||||||
|
for (int i = 0; i < k_digits; i++) {
|
||||||
|
dst[i] /= sum;
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue