mtl : adapt the MNIST example as starter
parent 98c267fc77
commit b23fe8c9c7
4 changed files with 458 additions and 0 deletions
@@ -2,6 +2,28 @@ set(TARGET mtl-export)
add_executable(${TARGET} mtl-export.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

if(TARGET BUILD_INFO)
  add_dependencies(${TARGET} BUILD_INFO)
endif()

if (APPLE)
    #
    # mtl

    find_library(FOUNDATION_LIBRARY         Foundation              REQUIRED)
    find_library(METAL_FRAMEWORK            Metal                   REQUIRED)
    find_library(METALKIT_FRAMEWORK         MetalKit                REQUIRED)
    find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)

    set(TEST_TARGET mtl)
    add_executable(${TEST_TARGET} mtl.cpp mtl.h mtl.m)
    target_link_libraries(${TEST_TARGET} PRIVATE
        ggml
        ${FOUNDATION_LIBRARY}
        ${METAL_FRAMEWORK}
        ${METALKIT_FRAMEWORK}
        ${METALPERFORMANCE_FRAMEWORK}
    )
endif()
51 examples/mtl/mtl.cpp Normal file
@@ -0,0 +1,51 @@
#include "ggml.h"
#include "mtl.h"

#include <cstdio>
#include <cstring>
#include <cstdlib>

int main(int argc, char ** argv) {
    ggml_time_init();

    if (argc != 2) {
        fprintf(stderr, "Usage: %s llama.ggml\n", argv[0]);
        return -1;
    }

    const char * fname_cgraph = argv[1];

    // load the compute graph
    struct ggml_context * ctx_data = NULL;
    struct ggml_context * ctx_eval = NULL;

    struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
    gf.n_threads = 1;

    // allocate work context
    static size_t buf_size = gf.work_size; // TODO
    static void * buf = malloc(buf_size);

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ buf,
        /*.no_alloc   =*/ false,
    };

    struct ggml_context * ctx_work = ggml_init(params);

    // this allocates all Metal resources and memory buffers
    auto * ctx_mtl = llama_mtl_init(ctx_data, ctx_eval, ctx_work, &gf);

    // the actual inference happens here
    llama_mtl_eval(ctx_mtl, &gf);

    llama_mtl_free(ctx_mtl);

    ggml_free(ctx_work);
    ggml_free(ctx_data);
    ggml_free(ctx_eval);

    return 0;
}
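
Note: the llama.ggml file this driver expects is a serialized ggml compute graph, presumably produced with ggml_graph_export by the mtl-export tool built above (the file name here just follows the usage string). A minimal sketch of such an export side, using a toy MNIST-like graph purely for illustration rather than the real llama graph; the tensor is named "input" because llama_mtl_eval below looks the input up by that name:

// sketch only: builds a tiny uninitialized graph and serializes it with
// ggml_graph_export; the real exporter lives in mtl-export.cpp
#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };

    struct ggml_context * ctx = ggml_init(params);

    // llama_mtl_eval() looks up the tensor named "input" before running the graph
    struct ggml_tensor * inp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 784);
    ggml_set_name(inp, "input");

    // weights are left uninitialized here; a real exporter would load them
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 784, 10);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 10);

    // ops covered by the Metal/MPS path in mtl.m: MUL_MAT, ADD, RELU, SOFT_MAX
    struct ggml_tensor * cur = ggml_mul_mat(ctx, w, inp);
    cur = ggml_add(ctx, cur, b);
    cur = ggml_relu(ctx, cur);
    cur = ggml_soft_max(ctx, cur);

    struct ggml_cgraph gf = ggml_build_forward(cur);

    ggml_graph_export(&gf, "llama.ggml");

    ggml_free(ctx);

    return 0;
}
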
28 examples/mtl/mtl.h Normal file
@@ -0,0 +1,28 @@
#pragma once

struct ggml_context;
struct ggml_cgraph;

#ifdef __cplusplus
extern "C" {
#endif

struct ggml_mtl_context;

struct ggml_mtl_context * llama_mtl_init(
        struct ggml_context * ctx_data,
        struct ggml_context * ctx_eval,
        struct ggml_context * ctx_work,
        struct ggml_cgraph  * gf);

void llama_mtl_free(struct ggml_mtl_context * ctx);

// return 0 on success
int llama_mtl_eval(
        struct ggml_mtl_context * ctx,
        struct ggml_cgraph * gf);

#ifdef __cplusplus
}
#endif
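
One small gap in the starter driver above: llama_mtl_eval is documented to return 0 on success (and mtl.m below returns -1 when it meets an op it has no kernel for), but main() ignores the result. A hedged sketch of how the call in mtl.cpp could act on that contract:

// sketch: same call as in mtl.cpp, but checking the documented return value
if (llama_mtl_eval(ctx_mtl, &gf) != 0) {
    fprintf(stderr, "%s: Metal eval failed (unsupported op in the graph?)\n", argv[0]);
    llama_mtl_free(ctx_mtl);
    return 1;
}
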
357 examples/mtl/mtl.m Normal file
@@ -0,0 +1,357 @@
#import "mtl.h"

#import "ggml.h"

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>

struct ggml_mtl_context {
    struct ggml_context * ctx_data;
    struct ggml_context * ctx_eval;
    struct ggml_context * ctx_work;

    id<MTLDevice>       device;
    id<MTLCommandQueue> queue;
    id<MTLLibrary>      library;

    id<MTLBuffer> buffer_data;
    id<MTLBuffer> buffer_eval;

    id<MTLBuffer> out;

    // custom kernels
    id<MTLFunction>             function_add;
    id<MTLComputePipelineState> pipeline_add;

    id<MTLFunction>             function_relu;
    id<MTLComputePipelineState> pipeline_relu;

    id<MTLFunction>             function_soft_max;
    id<MTLComputePipelineState> pipeline_soft_max;
};

// MSL code
NSString * const msl_library_llama = @"\
#include <metal_stdlib>                                 \n\
using namespace metal;                                  \n\
                                                        \n\
#define MAX(x, y) ((x) > (y) ? (x) : (y))               \n\
                                                        \n\
constant int k_digits [[function_constant(0)]];         \n\
                                                        \n\
kernel void kernel_add(                                 \n\
        device const float * src0,                      \n\
        device const float * src1,                      \n\
        device float * dst,                             \n\
        uint gid[[thread_position_in_grid]]) {          \n\
    dst[gid] = src0[gid] + src1[gid];                   \n\
}                                                       \n\
                                                        \n\
kernel void kernel_relu(                                \n\
        device const float * src,                       \n\
        device float * dst,                             \n\
        uint gid[[thread_position_in_grid]]) {          \n\
    dst[gid] = max(0.0f, src[gid]);                     \n\
}                                                       \n\
                                                        \n\
kernel void kernel_soft_max(                            \n\
        device const float * src,                       \n\
        device float * dst,                             \n\
        uint gid[[thread_position_in_grid]]) {          \n\
    float max = 0.0f;                                   \n\
    for (int i = 0; i < k_digits; i++) {                \n\
        max = MAX(max, src[i]);                         \n\
    }                                                   \n\
    float sum = 0.0f;                                   \n\
    for (int i = 0; i < k_digits; i++) {                \n\
        dst[i] = exp(src[i] - max);                     \n\
        sum += dst[i];                                  \n\
    }                                                   \n\
    for (int i = 0; i < k_digits; i++) {                \n\
        dst[i] /= sum;                                  \n\
    }                                                   \n\
}                                                       \n\
";

struct ggml_mtl_context * llama_mtl_init(
        struct ggml_context * ctx_data,
        struct ggml_context * ctx_eval,
        struct ggml_context * ctx_work,
        struct ggml_cgraph  * gf) {
    fprintf(stderr, "%s: allocating\n", __func__);

    struct ggml_mtl_context * ctx = malloc(sizeof(struct ggml_mtl_context));

    ctx->ctx_data = ctx_data;
    ctx->ctx_eval = ctx_eval;
    ctx->ctx_work = ctx_work;

    ctx->device = MTLCreateSystemDefaultDevice();
    ctx->queue  = [ctx->device newCommandQueue];

    // determine if we can use MPS
    if (MPSSupportsMTLDevice(ctx->device)) {
        fprintf(stderr, "%s: using MPS\n", __func__);
    } else {
        fprintf(stderr, "%s: not using MPS\n", __func__);
        GGML_ASSERT(false && "MPS not supported");
    }

    // compile from source string and show compile log
    {
        NSError * error = nil;
        ctx->library = [ctx->device newLibraryWithSource:msl_library_llama options:nil error:&error];
        if (error) {
            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
            exit(1);
        }
    }

    // load kernels
    {
        const int k_digits = 123;

        MTLFunctionConstantValues * constants = [MTLFunctionConstantValues new];
        [constants setConstantValue:&k_digits type:MTLDataTypeInt withName:@"k_digits"];

        ctx->function_add = [ctx->library newFunctionWithName:@"kernel_add"];
        ctx->pipeline_add = [ctx->device newComputePipelineStateWithFunction:ctx->function_add error:nil];
        fprintf(stderr, "%s: loaded kernel_add: %p\n", __func__, ctx->pipeline_add);

        ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
        ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
        fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, ctx->pipeline_relu);

        ctx->function_soft_max = [ctx->library newFunctionWithName:@"kernel_soft_max" constantValues:constants error:nil];
        ctx->pipeline_soft_max = [ctx->device newComputePipelineStateWithFunction:ctx->function_soft_max error:nil];
        fprintf(stderr, "%s: loaded kernel_soft_max: %p\n", __func__, ctx->pipeline_soft_max);
    }

    // MTLBuffer approach

    // pin ctx_data memory to GPU
    // use MTLStorageModeShared to allow us to initialize the weights from the CPU
    // TODO: how to use MTLStorageModeManaged?
    // TODO: see if we can avoid this copy somehow
    {
        const void * mem_buffer = ggml_get_mem_buffer(ctx_data);
        const size_t mem_size   = ggml_get_mem_size(ctx_data);

        ctx->buffer_data = [ctx->device newBufferWithBytes:mem_buffer length:mem_size options:MTLResourceStorageModeShared];

        fprintf(stderr, "%s: allocated data buffer, size = %8.2f MB\n", __func__, mem_size / 1024.0 / 1024.0);
    }

    // pin ctx_eval memory to GPU
    // this buffer will be used for the intermediate results of the evaluation
    {
        const size_t mem_size = ggml_get_mem_size(ctx_eval);

        ctx->buffer_eval = [ctx->device newBufferWithLength:mem_size options:MTLResourceStorageModePrivate];

        fprintf(stderr, "%s: allocated eval buffer, size = %8.2f MB\n", __func__, mem_size / 1024.0 / 1024.0);
    }

    // allocate buffer for result extraction
    {
        const size_t mem_size = ggml_nbytes(gf->nodes[gf->n_nodes - 1]);

        ctx->out = [ctx->device newBufferWithLength:mem_size options:MTLResourceStorageModeShared];

        fprintf(stderr, "%s: allocated out buffer, size = %8.2f MB\n", __func__, mem_size / 1024.0 / 1024.0);
    }

    return ctx;
}

void llama_mtl_free(struct ggml_mtl_context * ctx) {
    fprintf(stderr, "%s: deallocating\n", __func__);

    free(ctx);
}

// get data / eval buffer + offset
id<MTLBuffer> llama_mtl_get_buffer(struct ggml_mtl_context * ctx, struct ggml_tensor * t, size_t * offs) {
    const int64_t offs_data = (int64_t) t->data - (int64_t) ggml_get_mem_buffer(ctx->ctx_data);
    const int64_t offs_eval = (int64_t) t->data - (int64_t) ggml_get_mem_buffer(ctx->ctx_eval);

    const bool is_data = (offs_eval < 0) || (offs_data >= 0 && offs_data < offs_eval);

    const size_t t_size = ggml_nbytes(t);
    const size_t t_offs = is_data ? offs_data : offs_eval;

    id<MTLBuffer> result;

    if (is_data) {
        fprintf(stderr, "%s: data tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
        result = ctx->buffer_data;
    } else {
        fprintf(stderr, "%s: eval tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
        result = ctx->buffer_eval;
    }

    if (result == nil) {
        fprintf(stderr, "%s: error: buffer is nil\n", __func__);
        GGML_ASSERT(false);
    }

    if (offs != nil) {
        *offs = t_offs;
    }

    return result;
}

int llama_mtl_eval(
        struct ggml_mtl_context * ctx,
        struct ggml_cgraph * gf) {
    fprintf(stderr, "%s: evaluating\n", __func__);

    id<MTLCommandBuffer> command_buffer = [ctx->queue commandBuffer];
    id<MTLComputeCommandEncoder> encoder = nil;

    size_t offs_src0;
    size_t offs_src1;
    size_t offs_dst;

    // copy the input data to the GPU
    {
        struct ggml_tensor * inp = ggml_graph_get_tensor(gf, "input");

        id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, inp, &offs_src0);

        memcpy(id_dst.contents + offs_src0, inp->data, ggml_nbytes(inp));
    }

    for (int i = 0; i < gf->n_nodes; ++i) {
        fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

        switch (gf->nodes[i]->op) {
            case GGML_OP_ADD:
                {
                    if (encoder == nil) {
                        encoder = [command_buffer computeCommandEncoder];
                    }

                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
                    id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);

                    [encoder setComputePipelineState:ctx->pipeline_add];
                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];

                    const int64_t n = ggml_nelements(gf->nodes[i]);

                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                } break;
            case GGML_OP_RELU:
                {
                    if (encoder == nil) {
                        encoder = [command_buffer computeCommandEncoder];
                    }

                    id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
                    id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);

                    [encoder setComputePipelineState:ctx->pipeline_relu];
                    [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_dst offset:offs_dst  atIndex:1];

                    const int64_t n = ggml_nelements(gf->nodes[i]);

                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                } break;
            case GGML_OP_SOFT_MAX:
                {
                    if (encoder == nil) {
                        encoder = [command_buffer computeCommandEncoder];
                    }

                    id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
                    id<MTLBuffer> id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);

                    [encoder setComputePipelineState:ctx->pipeline_soft_max];
                    [encoder setBuffer:id_src offset:offs_src0 atIndex:0];
                    [encoder setBuffer:id_dst offset:offs_dst  atIndex:1];

                    [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                } break;
            case GGML_OP_MUL_MAT:
                {
                    if (encoder != nil) {
                        [encoder endEncoding];
                        encoder = nil;
                    }

                    // use MPSMatrixMultiplication
                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
                    id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);

                    const int64_t ncols0 = gf->nodes[i]->src0->ne[0];
                    const int64_t nrows0 = gf->nodes[i]->src0->ne[1];

                    const int64_t ncols1 = gf->nodes[i]->src1->ne[0];
                    const int64_t nrows1 = gf->nodes[i]->src1->ne[1];

                    const int64_t ncols2 = gf->nodes[i]->ne[0];
                    const int64_t nrows2 = gf->nodes[i]->ne[1];

                    GGML_ASSERT(ncols0 == ncols1);

                    MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
                        matrixDescriptorWithRows:nrows0 columns:ncols0 rowBytes:gf->nodes[i]->src0->nb[1] dataType:MPSDataTypeFloat32];
                    MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
                        matrixDescriptorWithRows:nrows1 columns:ncols1 rowBytes:gf->nodes[i]->src1->nb[1] dataType:MPSDataTypeFloat32];
                    MPSMatrixDescriptor * desc2 = [MPSMatrixDescriptor
                        matrixDescriptorWithRows:nrows2 columns:ncols2 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];

                    MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0 descriptor:desc0];
                    MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1 descriptor:desc1];
                    MPSMatrix * mat_dst  = [[MPSMatrix alloc] initWithBuffer:id_dst  offset:offs_dst  descriptor:desc2];

                    MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc] initWithDevice:ctx->device
                        transposeLeft:false transposeRight:true resultRows:nrows1 resultColumns:nrows0 interiorColumns:ncols0 alpha:1.0 beta:0.0];

                    [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
                } break;
            default:
                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
                GGML_ASSERT(false);
                return -1;
        }
    }

    // extract results from the GPU
    {
        if (encoder != nil) {
            [encoder endEncoding];
            encoder = nil;
        }

        struct ggml_tensor * out = gf->nodes[gf->n_nodes - 1];

        id<MTLBuffer> id_src = llama_mtl_get_buffer(ctx, out, &offs_src0);
        id<MTLBuffer> id_dst = ctx->out;

        id<MTLBlitCommandEncoder> encoder_blit = [command_buffer blitCommandEncoder];
        [encoder_blit copyFromBuffer:id_src sourceOffset:offs_src0 toBuffer:id_dst destinationOffset:0 size:ggml_nbytes(out)];
        [encoder_blit endEncoding];
    }

    [command_buffer commit];
    [command_buffer waitUntilCompleted];

    {
        const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
        fprintf(stderr, "%s: time elapsed = %f\n", __func__, time_elapsed);
    }

    // TODO
    const float * logits = ctx->out.contents;

    return 0;
}
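
llama_mtl_eval currently ends with a TODO: the last node of the graph is blitted into ctx->out and read back as logits, but nothing is done with the values or returned to the caller. Since this is adapted from the MNIST example, the obvious next step is an argmax over that buffer; a sketch of what could follow the `const float * logits = ctx->out.contents;` line (assumption: the last node is the 1-D soft_max output):

// sketch: pick the most probable class from the values read back from the GPU
const int64_t n_classes = ggml_nelements(gf->nodes[gf->n_nodes - 1]);

int64_t best = 0;
for (int64_t i = 1; i < n_classes; ++i) {
    if (logits[i] > logits[best]) {
        best = i;
    }
}

fprintf(stderr, "%s: predicted class = %lld (p = %f)\n", __func__, (long long) best, (double) logits[best]);
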