diff --git a/examples/mtl/CMakeLists.txt b/examples/mtl/CMakeLists.txt
index 4dc0bc596..a8923405f 100644
--- a/examples/mtl/CMakeLists.txt
+++ b/examples/mtl/CMakeLists.txt
@@ -2,6 +2,28 @@ set(TARGET mtl-export)
 add_executable(${TARGET} mtl-export.cpp)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
 if(TARGET BUILD_INFO)
   add_dependencies(${TARGET} BUILD_INFO)
 endif()
+
+if (APPLE)
+    #
+    # mtl
+
+    find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
+    find_library(METAL_FRAMEWORK Metal REQUIRED)
+    find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
+    find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
+
+    set(TEST_TARGET mtl)
+    add_executable(${TEST_TARGET} mtl.cpp mtl.h mtl.m)
+    target_link_libraries(${TEST_TARGET} PRIVATE
+        ggml
+        ${FOUNDATION_LIBRARY}
+        ${METAL_FRAMEWORK}
+        ${METALKIT_FRAMEWORK}
+        ${METALPERFORMANCE_FRAMEWORK}
+        )
+endif()
+
diff --git a/examples/mtl/mtl.cpp b/examples/mtl/mtl.cpp
new file mode 100644
index 000000000..68e828d4e
--- /dev/null
+++ b/examples/mtl/mtl.cpp
@@ -0,0 +1,51 @@
+#include "ggml.h"
+#include "mtl.h"
+
+#include <cstdio>
+#include <cstring>
+#include <cstdlib>
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    if (argc != 2) {
+        fprintf(stderr, "Usage: %s llama.ggml\n", argv[0]);
+        return -1;
+    }
+
+    const char * fname_cgraph = argv[1];
+
+    // load the compute graph
+    struct ggml_context * ctx_data = NULL;
+    struct ggml_context * ctx_eval = NULL;
+
+    struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
+    gf.n_threads = 1;
+
+    // allocate work context
+    static size_t buf_size = gf.work_size; // TODO
+    static void * buf = malloc(buf_size);
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ buf,
+        /*.no_alloc   =*/ false,
+    };
+
+    struct ggml_context * ctx_work = ggml_init(params);
+
+    // this allocates all Metal resources and memory buffers
+    auto * ctx_mtl = llama_mtl_init(ctx_data, ctx_eval, ctx_work, &gf);
+
+    // the actual inference happens here
+    llama_mtl_eval(ctx_mtl, &gf);
+
+    llama_mtl_free(ctx_mtl);
+
+    ggml_free(ctx_work);
+    ggml_free(ctx_data);
+    ggml_free(ctx_eval);
+
+    return 0;
+}
+
diff --git a/examples/mtl/mtl.h b/examples/mtl/mtl.h
new file mode 100644
index 000000000..a40d57111
--- /dev/null
+++ b/examples/mtl/mtl.h
@@ -0,0 +1,28 @@
+#pragma once
+
+struct ggml_context;
+struct ggml_cgraph;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct ggml_mtl_context;
+
+struct ggml_mtl_context * llama_mtl_init(
+    struct ggml_context * ctx_data,
+    struct ggml_context * ctx_eval,
+    struct ggml_context * ctx_work,
+    struct ggml_cgraph  * gf);
+
+void llama_mtl_free(struct ggml_mtl_context * ctx);
+
+// return 0 on success
+int llama_mtl_eval(
+    struct ggml_mtl_context * ctx,
+    struct ggml_cgraph      * gf);
+
+#ifdef __cplusplus
+}
+#endif
+
diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
new file mode 100644
index 000000000..58f1f0371
--- /dev/null
+++ b/examples/mtl/mtl.m
@@ -0,0 +1,357 @@
+#import "mtl.h"
+
+#import "ggml.h"
+
+#import <Foundation/Foundation.h>
+#import <Metal/Metal.h>
+#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
+
+struct ggml_mtl_context {
+    struct ggml_context * ctx_data;
+    struct ggml_context * ctx_eval;
+    struct ggml_context * ctx_work;
+
+    id<MTLDevice>       device;
+    id<MTLCommandQueue> queue;
+    id<MTLLibrary>      library;
+
+    id<MTLBuffer> buffer_data;
+    id<MTLBuffer> buffer_eval;
+
+    id<MTLBuffer> out;
+
+    // custom kernels
+    id<MTLFunction>             function_add;
+    id<MTLComputePipelineState> pipeline_add;
+
+    id<MTLFunction>             function_relu;
+    id<MTLComputePipelineState> pipeline_relu;
+
+    id<MTLFunction>             function_soft_max;
+    id<MTLComputePipelineState> pipeline_soft_max;
+};
+
+// MSL code
+NSString * const msl_library_llama = @"\
+#include <metal_stdlib>                                                \n\
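+// simple reference kernels (not optimized); soft_max runs in a single thread \n\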
+using namespace metal;                                                 \n\
+                                                                       \n\
+#define MAX(x, y) ((x) > (y) ? (x) : (y))                              \n\
+                                                                       \n\
+constant int k_digits [[function_constant(0)]];                        \n\
+                                                                       \n\
+kernel void kernel_add(                                                \n\
+        device const float * src0,                                     \n\
+        device const float * src1,                                     \n\
+        device float * dst,                                            \n\
+        uint gid[[thread_position_in_grid]]) {                         \n\
+    dst[gid] = src0[gid] + src1[gid];                                  \n\
+}                                                                      \n\
+                                                                       \n\
+kernel void kernel_relu(                                               \n\
+        device const float * src,                                      \n\
+        device float * dst,                                            \n\
+        uint gid[[thread_position_in_grid]]) {                         \n\
+    dst[gid] = max(0.0f, src[gid]);                                    \n\
+}                                                                      \n\
+                                                                       \n\
+kernel void kernel_soft_max(                                           \n\
+        device const float * src,                                      \n\
+        device float * dst,                                            \n\
+        uint gid[[thread_position_in_grid]]) {                         \n\
+    float max = 0.0f;                                                  \n\
+    for (int i = 0; i < k_digits; i++) {                               \n\
+        max = MAX(max, src[i]);                                        \n\
+    }                                                                  \n\
+    float sum = 0.0f;                                                  \n\
+    for (int i = 0; i < k_digits; i++) {                               \n\
+        dst[i] = exp(src[i] - max);                                    \n\
+        sum += dst[i];                                                 \n\
+    }                                                                  \n\
+    for (int i = 0; i < k_digits; i++) {                               \n\
+        dst[i] /= sum;                                                 \n\
+    }                                                                  \n\
+}                                                                      \n\
+";
+
+struct ggml_mtl_context * llama_mtl_init(
+        struct ggml_context * ctx_data,
+        struct ggml_context * ctx_eval,
+        struct ggml_context * ctx_work,
+        struct ggml_cgraph * gf) {
+    fprintf(stderr, "%s: allocating\n", __func__);
+
+    struct ggml_mtl_context * ctx = malloc(sizeof(struct ggml_mtl_context));
+
+    ctx->ctx_data = ctx_data;
+    ctx->ctx_eval = ctx_eval;
+    ctx->ctx_work = ctx_work;
+
+    ctx->device = MTLCreateSystemDefaultDevice();
+    ctx->queue  = [ctx->device newCommandQueue];
+
+    // determine if we can use MPS
+    if (MPSSupportsMTLDevice(ctx->device)) {
+        fprintf(stderr, "%s: using MPS\n", __func__);
+    } else {
+        fprintf(stderr, "%s: not using MPS\n", __func__);
+        GGML_ASSERT(false && "MPS not supported");
+    }
+
+    // compile from source string and show compile log
+    {
+        NSError * error = nil;
+        ctx->library = [ctx->device newLibraryWithSource:msl_library_llama options:nil error:&error];
+        if (error) {
+            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
+            exit(1);
+        }
+    }
+
+    // load kernels
+    {
+        const int k_digits = 123;
+
+        MTLFunctionConstantValues * constants = [MTLFunctionConstantValues new];
+        [constants setConstantValue:&k_digits type:MTLDataTypeInt withName:@"k_digits"];
+
+        ctx->function_add = [ctx->library newFunctionWithName:@"kernel_add"];
+        ctx->pipeline_add = [ctx->device newComputePipelineStateWithFunction:ctx->function_add error:nil];
+        fprintf(stderr, "%s: loaded kernel_add: %p\n", __func__, ctx->pipeline_add);
+
+        ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
+        ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
+        fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, ctx->pipeline_relu);
+
+        ctx->function_soft_max = [ctx->library newFunctionWithName:@"kernel_soft_max" constantValues:constants error:nil];
+        ctx->pipeline_soft_max = [ctx->device newComputePipelineStateWithFunction:ctx->function_soft_max error:nil];
+        fprintf(stderr, "%s: loaded kernel_soft_max: %p\n", __func__, ctx->pipeline_soft_max);
+    }
+
+    // MTLBuffer approach
+
+    // pin ctx_data memory to GPU
+    // use MTLStorageModeShared to allow us to initialize the weights from the CPU
+    // TODO: how to use MTLStorageModeManaged?
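+    // note: newBufferWithBytes copies the weights from the ggml context into a buffer visible to both the CPU and the GPU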
+    // TODO: see if we can avoid this copy somehow
+    {
+        const void * mem_buffer = ggml_get_mem_buffer(ctx_data);
+        const size_t mem_size   = ggml_get_mem_size(ctx_data);
+
+        ctx->buffer_data = [ctx->device newBufferWithBytes:mem_buffer length:mem_size options:MTLResourceStorageModeShared];
+
+        fprintf(stderr, "%s: allocated data buffer, size = %8.2f MB\n", __func__, mem_size / 1024.0 / 1024.0);
+    }
+
+    // pin ctx_eval memory to GPU
+    // this buffer will be used for the intermediate results of the evaluation
+    {
+        const size_t mem_size = ggml_get_mem_size(ctx_eval);
+
+        ctx->buffer_eval = [ctx->device newBufferWithLength:mem_size options:MTLResourceStorageModePrivate];
+
+        fprintf(stderr, "%s: allocated eval buffer, size = %8.2f MB\n", __func__, mem_size / 1024.0 / 1024.0);
+    }
+
+    // allocate buffer for result extraction
+    {
+        const size_t mem_size = ggml_nbytes(gf->nodes[gf->n_nodes - 1]);
+
+        ctx->out = [ctx->device newBufferWithLength:mem_size options:MTLResourceStorageModeShared];
+
+        fprintf(stderr, "%s: allocated  out buffer, size = %8.2f MB\n", __func__, mem_size / 1024.0 / 1024.0);
+    }
+
+    return ctx;
+}
+
+void llama_mtl_free(struct ggml_mtl_context * ctx) {
+    fprintf(stderr, "%s: deallocating\n", __func__);
+
+    free(ctx);
+}
+
+// get data / eval buffer + offset
+id<MTLBuffer> llama_mtl_get_buffer(struct ggml_mtl_context * ctx, struct ggml_tensor * t, size_t * offs) {
+    const int64_t offs_data = (int64_t) t->data - (int64_t) ggml_get_mem_buffer(ctx->ctx_data);
+    const int64_t offs_eval = (int64_t) t->data - (int64_t) ggml_get_mem_buffer(ctx->ctx_eval);
+
+    const bool is_data = (offs_eval < 0) || (offs_data >= 0 && offs_data < offs_eval);
+
+    const size_t t_size = ggml_nbytes(t);
+    const size_t t_offs = is_data ? offs_data : offs_eval;
+
+    id<MTLBuffer> result;
+
+    if (is_data) {
+        fprintf(stderr, "%s: data tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
+        result = ctx->buffer_data;
+    } else {
+        fprintf(stderr, "%s: eval tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
+        result = ctx->buffer_eval;
+    }
+
+    if (result == nil) {
+        fprintf(stderr, "%s: error: buffer is nil\n", __func__);
+        GGML_ASSERT(false);
+    }
+
+    if (offs != nil) {
+        *offs = t_offs;
+    }
+
+    return result;
+}
+
+int llama_mtl_eval(
+        struct ggml_mtl_context * ctx,
+        struct ggml_cgraph * gf) {
+    fprintf(stderr, "%s: evaluating\n", __func__);
+
+    id<MTLCommandBuffer> command_buffer = [ctx->queue commandBuffer];
+    id<MTLComputeCommandEncoder> encoder = nil;
+
+    size_t offs_src0;
+    size_t offs_src1;
+    size_t offs_dst;
+
+    // copy the input data to the GPU
+    {
+        struct ggml_tensor * inp = ggml_graph_get_tensor(gf, "input");
+
+        id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, inp, &offs_src0);
+
+        memcpy(id_dst.contents + offs_src0, inp->data, ggml_nbytes(inp));
+    }
+
+    for (int i = 0; i < gf->n_nodes; ++i) {
+        fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+
+        switch (gf->nodes[i]->op) {
+            case GGML_OP_ADD:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    [encoder setComputePipelineState:ctx->pipeline_add];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                    const int64_t n = ggml_nelements(gf->nodes[i]);
+
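+                    // naive dispatch: one single-thread threadgroup per element (no threadgroup size tuning yet)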
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_RELU:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    [encoder setComputePipelineState:ctx->pipeline_relu];
+                    [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(gf->nodes[i]);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_SOFT_MAX:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    [encoder setComputePipelineState:ctx->pipeline_soft_max];
+                    [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst  atIndex:1];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_MUL_MAT:
+                {
+                    if (encoder != nil) {
+                        [encoder endEncoding];
+                        encoder = nil;
+                    }
+
+                    // use MPSMatrixMultiplication
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t ncols0 = gf->nodes[i]->src0->ne[0];
+                    const int64_t nrows0 = gf->nodes[i]->src0->ne[1];
+
+                    const int64_t ncols1 = gf->nodes[i]->src1->ne[0];
+                    const int64_t nrows1 = gf->nodes[i]->src1->ne[1];
+
+                    const int64_t ncols2 = gf->nodes[i]->ne[0];
+                    const int64_t nrows2 = gf->nodes[i]->ne[1];
+
+                    GGML_ASSERT(ncols0 == ncols1);
+
+                    MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
+                        matrixDescriptorWithRows:nrows0 columns:ncols0 rowBytes:gf->nodes[i]->src0->nb[1] dataType:MPSDataTypeFloat32];
+                    MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
+                        matrixDescriptorWithRows:nrows1 columns:ncols1 rowBytes:gf->nodes[i]->src1->nb[1] dataType:MPSDataTypeFloat32];
+                    MPSMatrixDescriptor * desc2 = [MPSMatrixDescriptor
+                        matrixDescriptorWithRows:nrows2 columns:ncols2 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];
+
+                    MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0 descriptor:desc0];
+                    MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1 descriptor:desc1];
+                    MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst  descriptor:desc2];
+
+                    // src1 is the left matrix and src0 (transposed) the right one, i.e. dst = src1 * src0^T,
+                    // matching the ggml_mul_mat convention
+                    MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc] initWithDevice:ctx->device
+                        transposeLeft:false transposeRight:true resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols0 alpha:1.0 beta:0.0];
+
+                    [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
+                } break;
+            default:
+                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+                GGML_ASSERT(false);
+                return -1;
+        }
+    }
+
+    // extract results from the GPU
+    {
+        if (encoder != nil) {
+            [encoder endEncoding];
+            encoder = nil;
+        }
+
+        struct ggml_tensor * out = gf->nodes[gf->n_nodes - 1];
+
+        id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, out, &offs_src0);
+        id<MTLBuffer> id_dst = ctx->out;
+
+        id<MTLBlitCommandEncoder> encoder_blit = [command_buffer blitCommandEncoder];
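+        // blit the result from the GPU-side eval buffer into the shared "out" buffer so the CPU can read it once the command buffer completes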
+        [encoder_blit copyFromBuffer:id_src sourceOffset:offs_src0 toBuffer:id_dst destinationOffset:0 size:ggml_nbytes(out)];
+        [encoder_blit endEncoding];
+    }
+
+    [command_buffer commit];
+    [command_buffer waitUntilCompleted];
+
+    {
+        const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
+        fprintf(stderr, "%s: time elapsed = %f\n", __func__, time_elapsed);
+    }
+
+    // TODO
+    const float * logits = ctx->out.contents;
+
+    return 0;
+}