Merge branch 'master' into gg/flash-attn

This commit is contained in:
Georgi Gerganov 2024-03-22 16:34:34 +02:00
commit 9495d3982d
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
136 changed files with 72720 additions and 64258 deletions

View file

@ -179,8 +179,9 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
//GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
GGML_METAL_KERNEL_TYPE_CONCAT,
@ -286,6 +287,11 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
id<MTLLibrary> metal_library;
// load library
//
// - first check if the library is embedded
// - then check if the library is in the bundle
// - if not found, load the source and compile it
// - if that fails, return NULL
{
NSBundle * bundle = nil;
#ifdef SWIFT_PACKAGE
@ -293,12 +299,21 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
#else
bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
#endif
NSError * error = nil;
NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
if (libPath != nil) {
#if GGML_METAL_EMBED_LIBRARY
const bool try_metallib = false;
#else
const bool try_metallib = true;
#endif
NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
if (try_metallib && path_lib != nil) {
// pre-compiled library found
NSURL * libURL = [NSURL fileURLWithPath:libPath];
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
NSURL * libURL = [NSURL fileURLWithPath:path_lib];
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
if (error) {
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
@ -311,38 +326,41 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
extern const char ggml_metallib_start[];
extern const char ggml_metallib_end[];
NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
NSString * src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
#else
GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
NSString * sourcePath;
NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
NSString * path_source;
NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, ggmlMetalPathResources ? [ggmlMetalPathResources UTF8String] : "nil");
GGML_METAL_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
if (ggmlMetalPathResources) {
sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"];
if (path_resource) {
path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
} else {
sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
}
if (sourcePath == nil) {
if (path_source == nil) {
GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
sourcePath = @"ggml-metal.metal";
path_source = @"ggml-metal.metal";
}
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
NSString * src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
if (error) {
GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
return NULL;
}
#endif
#endif // GGML_METAL_EMBED_LIBRARY
@autoreleasepool {
// dictionary of preprocessor macros
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
#ifdef GGML_QKK_64
prep[@"QK_K"] = @(64);
prep[@"GGML_QKK_64"] = @(1);
#endif
MTLCompileOptions* options = [MTLCompileOptions new];
@ -596,8 +614,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
@ -738,6 +757,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_IQ4_NL:
return true;
default:
return false;
@ -764,7 +786,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
}
}
static bool ggml_metal_graph_compute(
static enum ggml_status ggml_metal_graph_compute(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
@ -1388,6 +1410,14 @@ static bool ggml_metal_graph_compute(
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
switch (src0->type) {
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
default: break;
}
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
@ -1660,8 +1690,8 @@ static bool ggml_metal_graph_compute(
// TODO: make this more general
GGML_ASSERT(n_as <= 8);
// max size of the src1ids array in the kernel stack
GGML_ASSERT(ne11 <= 512);
// max size of the src1ids array in the kernel shared buffer
GGML_ASSERT(ne11 <= 4096);
const int64_t ne20 = src2 ? src2->ne[0] : 0;
const int64_t ne21 = src2 ? src2->ne[1] : 0;
@ -1702,6 +1732,14 @@ static bool ggml_metal_graph_compute(
ne20 % 32 == 0 && ne20 >= 64 &&
ne11 > ne11_mm_min) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
switch (src0->type) {
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
default: break;
}
id<MTLComputePipelineState> pipeline = nil;
switch (src2->type) {
@ -1759,7 +1797,7 @@ static bool ggml_metal_graph_compute(
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
}
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
@ -2532,13 +2570,14 @@ static bool ggml_metal_graph_compute(
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
switch (dstt) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
//case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
//case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
@ -2602,7 +2641,7 @@ static bool ggml_metal_graph_compute(
MTLCommandBufferStatus status = [command_buffer status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
return false;
return GGML_STATUS_FAILED;
}
}
@ -2611,7 +2650,7 @@ static bool ggml_metal_graph_compute(
}
}
return true;
return GGML_STATUS_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////
@ -2919,7 +2958,7 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffe
UNUSED(backend);
}
GGML_CALL static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
return ggml_metal_graph_compute(metal_ctx, cgraph);
@ -2944,6 +2983,12 @@ static struct ggml_backend_i ggml_backend_metal_i = {
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_metal_graph_compute,
/* .supports_op = */ ggml_backend_metal_supports_op,
/* .offload_op = */ NULL,
/* .event_new = */ NULL,
/* .event_free = */ NULL,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .event_synchronize = */ NULL,
};
void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data) {