mtl : plug Metal inference into llama.cpp (very quick-n-dirty)

Georgi Gerganov 2023-06-02 21:52:11 +03:00
parent 640a889632
commit 2f4e9d19cc
7 changed files with 186 additions and 61 deletions

CMakeLists.txt

@@ -384,11 +384,22 @@ endif()
 add_library(llama
             llama.cpp
             llama.h
-            llama-util.h)
+            llama-util.h
+            examples/mtl/mtl.h # TODO: METAL TMP
+            examples/mtl/mtl.m # TODO: METAL TMP
+            )
 
 target_include_directories(llama PUBLIC .)
 target_compile_features(llama PUBLIC cxx_std_11) # don't bump
-target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS})
+target_link_libraries(llama PRIVATE
+    ggml
+    ${LLAMA_EXTRA_LIBS}
+    ${FOUNDATION_LIBRARY}         # TODO: METAL TMP
+    ${METAL_FRAMEWORK}            # TODO: METAL TMP
+    ${METALKIT_FRAMEWORK}         # TODO: METAL TMP
+    ${METALPERFORMANCE_FRAMEWORK} # TODO: METAL TMP
+    )
+
+target_compile_definitions(llama PRIVATE LLAMA_MTL_NDEBUG) # TODO: METAL TMP
 
 if (BUILD_SHARED_LIBS)
     set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)

examples/common.cpp

@@ -301,6 +301,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.mem_test = true;
         } else if (arg == "--export") {
             params.export_cgraph = true;
+        } else if (arg == "--import") {
+            params.import_cgraph = true;
         } else if (arg == "--verbose-prompt") {
             params.verbose_prompt = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
@@ -441,6 +443,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
 #endif
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --export              export the computation graph to 'llama.ggml'\n");
+    fprintf(stderr, "  --import              import a computation graph from 'llama.ggml'\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
@@ -490,6 +493,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
     lparams.use_mlock  = params.use_mlock;
     lparams.logits_all = params.perplexity;
     lparams.embedding  = params.embedding;
+    lparams.cgraph     = params.import_cgraph;
 
     llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);

examples/common.h

@@ -72,6 +72,7 @@ struct gpt_params {
     bool use_mlock      = false; // use mlock to keep model in memory
     bool mem_test       = false; // compute maximum memory usage
     bool export_cgraph  = false; // export the computation graph
+    bool import_cgraph  = false; // import a computation graph
     bool verbose_prompt = false; // print prompt tokens before generation
 };
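
Taken together, the three hunks above plumb a new --import switch through the example front end: gpt_params_parse() records it in params.import_cgraph, the usage text documents it, and llama_init_from_gpt_params() forwards it to the library as lparams.cgraph. A minimal sketch of that forwarding step, written as a hypothetical standalone helper (the field names come from the diff; the helper itself is not part of the tree):

    #include "common.h"
    #include "llama.h"

    // Hypothetical helper mirroring the llama_init_from_gpt_params() change:
    // the example-level --import flag becomes the library-level cgraph flag.
    static llama_context * example_init(const gpt_params & params) {
        llama_context_params lparams = llama_context_default_params();

        lparams.n_ctx  = params.n_ctx;
        lparams.cgraph = params.import_cgraph; // request the pre-exported "llama.ggml" graph

        return llama_init_from_file(params.model.c_str(), lparams);
    }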

examples/mtl/mtl.h

@@ -25,6 +25,8 @@ int llama_mtl_eval(
         int   n_tokens,
         int   n_past);
 
+float * llama_mtl_get_logits(struct ggml_mtl_context * ctx);
+
 #ifdef __cplusplus
 }
 #endif
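
The header gains a way to read back the logits of the most recent Metal evaluation. The full llama_mtl_eval() signature is only partially visible in this hunk; the shape assumed below is taken from its call site in llama.cpp further down. A hedged sketch of how the two calls pair up (the wrapper function is hypothetical):

    #include "ggml.h"
    #include "examples/mtl/mtl.h"

    // Hypothetical wrapper: run the imported graph for one batch of tokens on
    // Metal and return a pointer to the resulting logits, or NULL on failure.
    static const float * eval_on_metal(struct ggml_mtl_context * mtl_ctx, struct ggml_cgraph * mtl_gf,
                                       const int * tokens, int n_tokens, int n_past) {
        if (llama_mtl_eval(mtl_ctx, mtl_gf, tokens, n_tokens, n_past) != 0) {
            return NULL;
        }
        return llama_mtl_get_logits(mtl_ctx); // points into the graph's output buffer
    }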

examples/mtl/mtl.m

@@ -6,11 +6,19 @@
 #import <Metal/Metal.h>
 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 
+#ifdef LLAMA_MTL_NDEBUG
+#define mtl_printf(...)
+#else
+#define mtl_printf(...) fprintf(stderr, __VA_ARGS__)
+#endif
+
 struct ggml_mtl_context {
     struct ggml_context * ctx_data;
     struct ggml_context * ctx_eval;
     struct ggml_context * ctx_work;
 
+    float * logits;
+
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
     id<MTLLibrary>      library;
@@ -274,7 +282,44 @@ int llama_mtl_eval(
         const int * tokens,
               int   n_tokens,
               int   n_past) {
-    fprintf(stderr, "%s: evaluating, n_tokens = %d, n_past = %d\n", __func__, n_tokens, n_past);
+    mtl_printf("%s: evaluating, n_tokens = %d, n_past = %d\n", __func__, n_tokens, n_past);
+
+    // adjust dynamic shapes
+    // TODO: wrong ...
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "embd");
+    //    t->ne[0] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Qpre");
+    //    t->src0->ne[2] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Kpre");
+    //    t->src0->ne[2] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Vcur");
+    //    t->ne[0] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * k = ggml_graph_get_tensor(gf, "k");
+    //    struct ggml_tensor * v = ggml_graph_get_tensor(gf, "v");
+    //    k->ne[0] = n_tokens*v->ne[1];
+    //    v->ne[0] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Q");
+    //    t->ne[1] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "K");
+    //    t->ne[1] = n_past + n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "KQV_merged_contiguous");
+    //    t->src1->ne[1] = n_tokens;
+    //}
 
     struct ggml_tensor * input = ggml_graph_get_tensor(gf, "embd");
     memcpy(input->data, tokens, n_tokens * sizeof(int));
@@ -296,7 +341,7 @@ int llama_mtl_eval(
     }
 
     for (int i = 0; i < gf->n_nodes; ++i) {
-        //fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+        //mtl_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
         struct ggml_tensor * src0 = gf->nodes[i]->src0;
         struct ggml_tensor * src1 = gf->nodes[i]->src1;
@@ -340,7 +385,21 @@ int llama_mtl_eval(
         id<MTLBuffer> id_src1 = src1 ? llama_mtl_get_buffer(ctx, src1, &offs_src1) : nil;
         id<MTLBuffer> id_dst  = dst  ? llama_mtl_get_buffer(ctx, dst,  &offs_dst)  : nil;
 
-        switch (gf->nodes[i]->op) {
+        //mtl_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+        //if (src0) {
+        //    mtl_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+        //        ggml_is_contiguous(src0), src0->name);
+        //}
+        //if (src1) {
+        //    mtl_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+        //        ggml_is_contiguous(src1), src1->name);
+        //}
+        //if (dst) {
+        //    mtl_printf("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
+        //        dst->name);
+        //}
+
+        switch (dst->op) {
             case GGML_OP_RESHAPE:
             case GGML_OP_VIEW:
             case GGML_OP_TRANSPOSE:
@@ -359,7 +418,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -369,7 +428,7 @@ int llama_mtl_eval(
                         encoder = [command_buffer computeCommandEncoder];
                     }
 
-                    if (ggml_nelements(gf->nodes[i]->src1) == ne10) {
+                    if (ggml_nelements(src1) == ne10) {
                         // src1 is a row
                         [encoder setComputePipelineState:ctx->pipeline_mul_row];
                     } else {
@@ -380,7 +439,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                     [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
 
-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -390,14 +449,14 @@ int llama_mtl_eval(
                         encoder = [command_buffer computeCommandEncoder];
                     }
 
-                    const float scale = *(const float *) gf->nodes[i]->src1->data;
+                    const float scale = *(const float *) src1->data;
 
                     [encoder setComputePipelineState:ctx->pipeline_scale];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                     [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -411,7 +470,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -425,7 +484,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -464,16 +523,11 @@ int llama_mtl_eval(
                 } break;
             case GGML_OP_MUL_MAT:
                 {
-                    //fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src0t), ne00, ne01, ne02, ggml_is_contiguous(gf->nodes[i]->src0));
-                    //fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src1t), ne10, ne11, ne12, ggml_is_contiguous(gf->nodes[i]->src1));
-                    //fprintf(stderr, "mul_mat: dst  - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
-                    //fprintf(stderr, "mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
-
                     GGML_ASSERT(ne00 == ne10);
                     GGML_ASSERT(ne02 == ne12);
 
-                    if (ggml_is_contiguous(gf->nodes[i]->src0) &&
-                        ggml_is_contiguous(gf->nodes[i]->src1) &&
+                    if (ggml_is_contiguous(src0) &&
+                        ggml_is_contiguous(src1) &&
                         (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
 
                         if (encoder != nil) {
@@ -486,13 +540,13 @@ int llama_mtl_eval(
                         // for F32 x F32 we use MPS
                         MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:gf->nodes[i]->src0->nb[1] dataType:src0dt];
+                            matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
 
                         MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:gf->nodes[i]->src1->nb[1] dataType:src1dt];
+                            matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
 
                         MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];
+                            matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
 
                         MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
                             initWithDevice:ctx->device transposeLeft:false transposeRight:true
@@ -573,22 +627,22 @@ int llama_mtl_eval(
                         encoder = [command_buffer computeCommandEncoder];
                     }
 
-                    switch (gf->nodes[i]->src0->type) {
+                    switch (src0->type) {
                         case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                         default: {
                             // not implemented
-                            fprintf(stderr, "%s: node %3d, op = %8s, type = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op), ggml_type_name(gf->nodes[i]->src0->type));
+                            fprintf(stderr, "%s: node %3d, op = %8s, type = %8s not implemented\n", __func__, i, ggml_op_name(dst->op), ggml_type_name(src0->type));
                         }
                     }
 
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBytes:&(gf->nodes[i]->src0->ne[0]) length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&(gf->nodes[i]->src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
-                    [encoder setBytes:&(gf->nodes[i]->nb[1])       length:sizeof(uint64_t) atIndex:5];
+                    [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
+                    [encoder setBytes:&(dst->nb[1])  length:sizeof(uint64_t) atIndex:5];
 
-                    const int64_t n = ggml_nelements(gf->nodes[i]->src1);
+                    const int64_t n = ggml_nelements(src1);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -610,7 +664,7 @@ int llama_mtl_eval(
                     [encoder setBytes:&eps length:sizeof( float) atIndex:4];
                     [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
-                    const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
+                    const int64_t nrows = ggml_nrows(src0);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
@@ -620,12 +674,12 @@ int llama_mtl_eval(
                         encoder = [command_buffer computeCommandEncoder];
                     }
 
-                    const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
-                    const int mode   = ((int32_t *) gf->nodes[i]->src1->data)[2];
+                    const int n_dims = ((int32_t *) src1->data)[1];
+                    const int mode   = ((int32_t *) src1->data)[2];
 
-                    //fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
-                    //fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
-                    //fprintf(stderr, "rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
+                    //mtl_printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                    //mtl_printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                    //mtl_printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
 
                     [encoder setComputePipelineState:ctx->pipeline_rope];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -660,11 +714,11 @@ int llama_mtl_eval(
                     const int nth = 32;
 
-                    //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
-                    //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
-                    //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
-                    //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
-                    //fprintf(stderr, "cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
+                    //mtl_printf("cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                    //mtl_printf("cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
+                    //mtl_printf("cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                    //mtl_printf("cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
+                    //mtl_printf("cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
 
                     switch (src0t) {
                         case GGML_TYPE_F32:
@@ -700,7 +754,7 @@ int llama_mtl_eval(
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
             default:
-                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                 GGML_ASSERT(false);
                 return -1;
         }
@@ -708,7 +762,7 @@ int llama_mtl_eval(
     // extract results from the GPU
     {
-        fprintf(stderr, "%s: extract results from the GPU\n", __func__);
+        mtl_printf("%s: extract results from the GPU\n", __func__);
 
         if (encoder != nil) {
             [encoder endEncoding];
@@ -730,18 +784,19 @@ int llama_mtl_eval(
     {
         const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
-        printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
+        mtl_printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
     }
 
-    // TODO
-    const float * logits = ctx->out.contents;
+    ctx->logits = ctx->out.contents;
+
+    const float * logits = ctx->logits;
 
 #if 1
-    printf("logits: ");
+    mtl_printf("logits: ");
     for (int i = 0; i < 100; i++) {
-        printf("%8.4f ", logits[i]);
+        mtl_printf("%8.4f ", logits[i]);
     }
-    printf("\n");
+    mtl_printf("\n");
     double sum = 0.0;
     int imax = 0;
     double vmax = -INFINITY;
@@ -752,7 +807,7 @@ int llama_mtl_eval(
             imax = i;
         }
     }
-    printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
+    mtl_printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
 #endif
 
     //{
@@ -801,3 +856,7 @@ int llama_mtl_eval(
 
     return 0;
 }
+
+float * llama_mtl_get_logits(struct ggml_mtl_context * ctx) {
+    return ctx->logits;
+}
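
One detail worth noting in the hunks above: the per-node logging now goes through the mtl_printf() macro, and the CMakeLists.txt hunk defines LLAMA_MTL_NDEBUG for the llama target, so these messages compile away by default. A standalone illustration of the same pattern (not part of the commit):

    #include <stdio.h>

    // Same macro as in mtl.m: with LLAMA_MTL_NDEBUG defined the call expands to
    // nothing, otherwise it forwards to fprintf(stderr, ...).
    #ifdef LLAMA_MTL_NDEBUG
    #define mtl_printf(...)
    #else
    #define mtl_printf(...) fprintf(stderr, __VA_ARGS__)
    #endif

    int main(void) {
        mtl_printf("encoding node %3d, op = %8s\n", 0, "MUL_MAT"); // silent when LLAMA_MTL_NDEBUG is defined
        return 0;
    }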

llama.cpp

@@ -9,6 +9,9 @@
 #include "llama-util.h"
 #include "llama.h"
 
+// METAL
+#include "examples/mtl/mtl.h"
+
 #include "ggml.h"
 #ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
@@ -238,6 +241,10 @@ struct llama_context {
     llama_ctx_buffer buf_compute;
     llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
 
+    // METAL
+    ggml_mtl_context * mtl_ctx = NULL;
+    ggml_cgraph mtl_gf;
+
     int    buf_last = 0;
     size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
@@ -836,6 +843,7 @@ struct llama_context_params llama_context_default_params() {
         /*.use_mmap                    =*/ true,
         /*.use_mlock                   =*/ false,
         /*.embedding                   =*/ false,
+        /*.cgraph                      =*/ false,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
     };
@@ -1270,8 +1278,14 @@ static bool llama_eval_internal(
             //struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
 
             // compute Q and K and RoPE them
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
-            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+            struct ggml_tensor * Qpre = ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N);
+            struct ggml_tensor * Kpre = ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N);
+            ggml_set_name(Qpre, "Qpre");
+            ggml_set_name(Kpre, "Kpre");
+
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, Qpre, n_past, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, Kpre, n_past, n_rot, 0);
             ggml_set_name(Qcur, "Qcur");
             ggml_set_name(Kcur, "Kcur");
@@ -1279,22 +1293,19 @@ static bool llama_eval_internal(
             {
                 // compute the transposed [N, n_embd] V matrix
                 struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
+                ggml_set_name(Vcur, "Vcur");
 
                 struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
                 struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
                         (   n_ctx)*ggml_element_size(kv_self.v),
                         (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
 
-                //struct ggml_tensor * t = ggml_cpy(ctx0, Vcur, v);
-                //// TODO: TMP !!!!
-                //if (il == 0) {
-                //    ggml_set_name(t, "mtl-check");
-                //}
+                ggml_set_name(k, "k");
+                ggml_set_name(v, "v");
 
                 // important: storing RoPE-ed version of K in the KV cache!
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
-
-                //ggml_build_forward_expand(&gf, t);
             }
 
             struct ggml_tensor * Q =
@@ -2391,9 +2402,25 @@ struct llama_context * llama_init_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
+    // METAL
+    if (params.cgraph) {
+        params.vocab_only = true;
+
+        // load the compute graph
+        struct ggml_context * ctx_data = NULL;
+        struct ggml_context * ctx_eval = NULL;
+
+        struct ggml_cgraph gf = ggml_graph_import("llama.ggml", &ctx_data, &ctx_eval);
+        gf.n_threads = 1;
+
+        // this allocates all Metal resources and memory buffers
+        ctx->mtl_ctx = llama_mtl_init(ctx_data, ctx_eval, NULL, &gf);
+        ctx->mtl_gf  = gf;
+    }
+
     if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type,
                           params.use_mmap, params.use_mlock, params.vocab_only,
                           params.progress_callback, params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
         return nullptr;
@@ -2411,7 +2438,11 @@ struct llama_context * llama_init_from_file(
             const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
             fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
         }
+    }
 
+    // METAL
+    // TODO: changed the behavior here for vocab_only -- reconsider implications later
+    {
         const auto & hparams = ctx->model.hparams;
 
         // resized during inference
@@ -3046,9 +3077,25 @@ int llama_eval(
                          int   n_tokens,
                          int   n_past,
                          int   n_threads) {
-    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
-        fprintf(stderr, "%s: failed to eval\n", __func__);
-        return 1;
+    // METAL
+    if (ctx->mtl_ctx) {
+        llama_mtl_eval(ctx->mtl_ctx, &ctx->mtl_gf, tokens, n_tokens, n_past);
+
+        const float * logits = llama_mtl_get_logits(ctx->mtl_ctx);
+
+        // extract logits
+        {
+            const int n_vocab = ctx->model.hparams.n_vocab;
+
+            auto & logits_out = ctx->logits;
+
+            logits_out.resize(n_vocab);
+            memcpy(logits_out.data(), logits, sizeof(float)*n_vocab);
+        }
+    } else {
+        if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
+            fprintf(stderr, "%s: failed to eval\n", __func__);
+            return 1;
+        }
     }
 
     // get a more accurate load time, upon first eval
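
End to end, the intended flow is: run once with --export so the computation graph is written to "llama.ggml" (that flag already existed before this change), then run with --import, i.e. cgraph = true, so llama_init_from_file() loads the graph, sets up Metal, and llama_eval() routes through llama_mtl_eval() instead of llama_eval_internal(). A minimal caller-side sketch, assuming "llama.ggml" is present in the working directory and using a placeholder model path:

    #include "llama.h"

    int main() {
        llama_context_params lparams = llama_context_default_params();
        lparams.cgraph = true; // try to load "llama.ggml" and evaluate it on Metal

        // placeholder model path; with cgraph = true only the vocab is loaded (vocab_only is forced)
        llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", lparams);
        if (ctx == nullptr) {
            return 1;
        }

        const llama_token tokens[1] = { llama_token_bos() };
        if (llama_eval(ctx, tokens, 1, /*n_past =*/ 0, /*n_threads =*/ 1) != 0) {
            return 1;
        }

        // llama_eval copied n_vocab logits out of the Metal output buffer above
        const float * logits = llama_get_logits(ctx);
        (void) logits;

        llama_free(ctx);
        return 0;
    }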

llama.h

@@ -75,6 +75,7 @@ extern "C" {
         bool use_mmap;   // use mmap if possible
         bool use_mlock;  // force system to keep model in RAM
         bool embedding;  // embedding mode only
+        bool cgraph;     // try to load computation graph from "llama.ggml" (METAL)
 
         // called with a progress value between 0 and 1, pass NULL to disable
         llama_progress_callback progress_callback;