Merge branch 'master' into gg/flash-attn

This commit is contained in:
Georgi Gerganov 2024-01-28 21:53:51 +02:00
commit 0ad44baf33
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
36 changed files with 85118 additions and 89 deletions

View file

@ -24,10 +24,7 @@
#define UNUSED(x) (void)(x)
#define GGML_METAL_MAX_KERNELS 256
struct ggml_metal_kernel {
id<MTLFunction> function;
id<MTLComputePipelineState> pipeline;
};
@ -162,11 +159,10 @@ struct ggml_metal_context {
id<MTLDevice> device;
id<MTLCommandQueue> queue;
id<MTLLibrary> library;
dispatch_queue_t d_queue;
struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
bool support_simdgroup_reduction;
bool support_simdgroup_mm;
@ -249,6 +245,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
ctx->queue = [ctx->device newCommandQueue];
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
id<MTLLibrary> metal_library;
// load library
{
NSBundle * bundle = nil;
@ -263,7 +261,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
// pre-compiled library found
NSURL * libURL = [NSURL fileURLWithPath:libPath];
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
if (error) {
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
@ -305,7 +303,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
//[options setFastMathEnabled:false];
ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
if (error) {
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
@ -370,8 +368,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
{
NSError * error = nil;
for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
ctx->kernels[i].function = nil;
for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
ctx->kernels[i].pipeline = nil;
}
@ -383,13 +380,15 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
#define GGML_METAL_ADD_KERNEL(e, name, supported) \
if (supported) { \
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
[metal_function release]; \
GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
(int) kernel->pipeline.threadExecutionWidth); \
if (error) { \
GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
[metal_library release]; \
return NULL; \
} \
} else { \
@ -521,23 +520,17 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
}
[metal_library release];
return ctx;
}
static void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
if (ctx->kernels[i].pipeline) {
[ctx->kernels[i].pipeline release];
}
if (ctx->kernels[i].function) {
[ctx->kernels[i].function release];
}
for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) {
[ctx->kernels[i].pipeline release];
}
[ctx->library release];
[ctx->queue release];
[ctx->device release];
@ -2504,6 +2497,7 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
/* .get_name = */ ggml_backend_metal_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // TODO: return device.maxBufferLength
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .supports_backend = */ ggml_backend_metal_buffer_type_supports_backend,
/* .is_host = */ ggml_backend_metal_buffer_type_is_host,