mtl : plug Metal inference into llama.cpp (very quick-n-dirty)

Georgi Gerganov 2023-06-02 21:52:11 +03:00
parent 640a889632
commit 2f4e9d19cc
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
7 changed files with 186 additions and 61 deletions

CMakeLists.txt

@@ -384,11 +384,22 @@ endif()
add_library(llama
llama.cpp
llama.h
llama-util.h)
llama-util.h
examples/mtl/mtl.h # TODO: METAL TMP
examples/mtl/mtl.m # TODO: METAL TMP
)
target_include_directories(llama PUBLIC .)
target_compile_features(llama PUBLIC cxx_std_11) # don't bump
target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS})
target_link_libraries(llama PRIVATE
ggml
${LLAMA_EXTRA_LIBS}
${FOUNDATION_LIBRARY} # TODO: METAL TMP
${METAL_FRAMEWORK} # TODO: METAL TMP
${METALKIT_FRAMEWORK} # TODO: METAL TMP
${METALPERFORMANCE_FRAMEWORK} # TODO: METAL TMP
)
target_compile_definitions(llama PRIVATE LLAMA_MTL_NDEBUG) # TODO: METAL TMP
if (BUILD_SHARED_LIBS)
set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)

examples/common.cpp

@@ -301,6 +301,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.mem_test = true;
} else if (arg == "--export") {
params.export_cgraph = true;
} else if (arg == "--import") {
params.import_cgraph = true;
} else if (arg == "--verbose-prompt") {
params.verbose_prompt = true;
} else if (arg == "-r" || arg == "--reverse-prompt") {
@@ -441,6 +443,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
#endif
fprintf(stderr, " --mtest compute maximum memory usage\n");
fprintf(stderr, " --export export the computation graph to 'llama.ggml'\n");
fprintf(stderr, " --import import a computation graph from 'llama.ggml'\n");
fprintf(stderr, " --verbose-prompt print prompt before generation\n");
fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
@@ -490,6 +493,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
lparams.use_mlock = params.use_mlock;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
lparams.cgraph = params.import_cgraph;
llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);

examples/common.h

@@ -72,6 +72,7 @@ struct gpt_params {
bool use_mlock = false; // use mlock to keep model in memory
bool mem_test = false; // compute maximum memory usage
bool export_cgraph = false; // export the computation graph
bool import_cgraph = false; // import a computation graph
bool verbose_prompt = false; // print prompt tokens before generation
};
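
Taken together, the two flags are meant to be used as a pair: one run exports the graph with --export, a later run loads it with --import. A rough sketch of a driver built on these params (not part of this commit; the helper name init_ctx is hypothetical and the include paths are assumed to follow the examples/ layout):

#include <string>
#include "common.h"   // examples/common.h
#include "llama.h"

// Hypothetical helper: `use_exported_graph` mirrors the new --import flag.
static struct llama_context * init_ctx(const std::string & model, bool use_exported_graph) {
    gpt_params params;
    params.model         = model;
    params.import_cgraph = use_exported_graph; // forwarded to lparams.cgraph in common.cpp above

    // llama_init_from_gpt_params() is the helper extended in examples/common.cpp above
    return llama_init_from_gpt_params(params);
}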

examples/mtl/mtl.h

@@ -25,6 +25,8 @@ int llama_mtl_eval(
int n_tokens,
int n_past);
float * llama_mtl_get_logits(struct ggml_mtl_context * ctx);
#ifdef __cplusplus
}
#endif
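
With llama_mtl_get_logits() exposed, the header is enough to drive a full evaluation on its own. A rough usage sketch (not part of this commit; the llama_mtl_init() arguments are inferred from the call site added to llama.cpp below, and run_metal_graph() is a hypothetical wrapper):

#include "examples/mtl/mtl.h"
#include "ggml.h"

// Hypothetical driver: assumes the graph was previously written to "llama.ggml"
// (e.g. via the new --export flag).
static int run_metal_graph(const int * tokens, int n_tokens, int n_past) {
    struct ggml_context * ctx_data = NULL;
    struct ggml_context * ctx_eval = NULL;

    // deserialize the compute graph together with its data/eval contexts
    struct ggml_cgraph gf = ggml_graph_import("llama.ggml", &ctx_data, &ctx_eval);
    gf.n_threads = 1;

    // allocate all Metal resources and memory buffers
    struct ggml_mtl_context * mtl = llama_mtl_init(ctx_data, ctx_eval, /*ctx_work =*/ NULL, &gf);

    // encode the graph, run it on the GPU, then read the logits back
    llama_mtl_eval(mtl, &gf, tokens, n_tokens, n_past);
    const float * logits = llama_mtl_get_logits(mtl);

    return logits != NULL ? 0 : -1;
}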

examples/mtl/mtl.m

@@ -6,11 +6,19 @@
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#ifdef LLAMA_MTL_NDEBUG
#define mtl_printf(...)
#else
#define mtl_printf(...) fprintf(stderr, __VA_ARGS__)
#endif
struct ggml_mtl_context {
struct ggml_context * ctx_data;
struct ggml_context * ctx_eval;
struct ggml_context * ctx_work;
float * logits;
id<MTLDevice> device;
id<MTLCommandQueue> queue;
id<MTLLibrary> library;
@@ -274,7 +282,44 @@ int llama_mtl_eval(
const int * tokens,
int n_tokens,
int n_past) {
fprintf(stderr, "%s: evaluating, n_tokens = %d, n_past = %d\n", __func__, n_tokens, n_past);
mtl_printf("%s: evaluating, n_tokens = %d, n_past = %d\n", __func__, n_tokens, n_past);
// adjust dynamic shapes
// TODO: wrong ...
//{
// struct ggml_tensor * t = ggml_graph_get_tensor(gf, "embd");
// t->ne[0] = n_tokens;
//}
//{
// struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Qpre");
// t->src0->ne[2] = n_tokens;
//}
//{
// struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Kpre");
// t->src0->ne[2] = n_tokens;
//}
//{
// struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Vcur");
// t->ne[0] = n_tokens;
//}
//{
// struct ggml_tensor * k = ggml_graph_get_tensor(gf, "k");
// struct ggml_tensor * v = ggml_graph_get_tensor(gf, "v");
// k->ne[0] = n_tokens*v->ne[1];
// v->ne[0] = n_tokens;
//}
//{
// struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Q");
// t->ne[1] = n_tokens;
//}
//{
// struct ggml_tensor * t = ggml_graph_get_tensor(gf, "K");
// t->ne[1] = n_past + n_tokens;
//}
//{
// struct ggml_tensor * t = ggml_graph_get_tensor(gf, "KQV_merged_contiguous");
// t->src1->ne[1] = n_tokens;
//}
struct ggml_tensor * input = ggml_graph_get_tensor(gf, "embd");
memcpy(input->data, tokens, n_tokens * sizeof(int));
@@ -296,7 +341,7 @@ int llama_mtl_eval(
}
for (int i = 0; i < gf->n_nodes; ++i) {
//fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
//mtl_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
struct ggml_tensor * src0 = gf->nodes[i]->src0;
struct ggml_tensor * src1 = gf->nodes[i]->src1;
@@ -340,7 +385,21 @@ int llama_mtl_eval(
id<MTLBuffer> id_src1 = src1 ? llama_mtl_get_buffer(ctx, src1, &offs_src1) : nil;
id<MTLBuffer> id_dst = dst ? llama_mtl_get_buffer(ctx, dst, &offs_dst) : nil;
switch (gf->nodes[i]->op) {
//mtl_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
//if (src0) {
// mtl_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
// ggml_is_contiguous(src0), src0->name);
//}
//if (src1) {
// mtl_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
// ggml_is_contiguous(src1), src1->name);
//}
//if (dst) {
// mtl_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
// dst->name);
//}
switch (dst->op) {
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
@@ -359,7 +418,7 @@ int llama_mtl_eval(
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
const int64_t n = ggml_nelements(gf->nodes[i]);
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@@ -369,7 +428,7 @@ int llama_mtl_eval(
encoder = [command_buffer computeCommandEncoder];
}
if (ggml_nelements(gf->nodes[i]->src1) == ne10) {
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_mul_row];
} else {
@@ -380,7 +439,7 @@ int llama_mtl_eval(
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
const int64_t n = ggml_nelements(gf->nodes[i]);
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@@ -390,14 +449,14 @@ int llama_mtl_eval(
encoder = [command_buffer computeCommandEncoder];
}
const float scale = *(const float *) gf->nodes[i]->src1->data;
const float scale = *(const float *) src1->data;
[encoder setComputePipelineState:ctx->pipeline_scale];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
const int64_t n = ggml_nelements(gf->nodes[i]);
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@@ -411,7 +470,7 @@ int llama_mtl_eval(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(gf->nodes[i]);
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@@ -425,7 +484,7 @@ int llama_mtl_eval(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(gf->nodes[i]);
const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@@ -464,16 +523,11 @@ int llama_mtl_eval(
} break;
case GGML_OP_MUL_MAT:
{
//fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src0t), ne00, ne01, ne02, ggml_is_contiguous(gf->nodes[i]->src0));
//fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src1t), ne10, ne11, ne12, ggml_is_contiguous(gf->nodes[i]->src1));
//fprintf(stderr, "mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
//fprintf(stderr, "mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
GGML_ASSERT(ne00 == ne10);
GGML_ASSERT(ne02 == ne12);
if (ggml_is_contiguous(gf->nodes[i]->src0) &&
ggml_is_contiguous(gf->nodes[i]->src1) &&
if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
if (encoder != nil) {
@@ -486,13 +540,13 @@ int llama_mtl_eval(
// for F32 x F32 we use MPS
MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:gf->nodes[i]->src0->nb[1] dataType:src0dt];
matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:gf->nodes[i]->src1->nb[1] dataType:src1dt];
matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];
matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
initWithDevice:ctx->device transposeLeft:false transposeRight:true
@@ -573,22 +627,22 @@ int llama_mtl_eval(
encoder = [command_buffer computeCommandEncoder];
}
switch (gf->nodes[i]->src0->type) {
switch (src0->type) {
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
default: {
// not implemented
fprintf(stderr, "%s: node %3d, op = %8s, type = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op), ggml_type_name(gf->nodes[i]->src0->type));
fprintf(stderr, "%s: node %3d, op = %8s, type = %8s not implemented\n", __func__, i, ggml_op_name(dst->op), ggml_type_name(src0->type));
}
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&(gf->nodes[i]->src0->ne[0]) length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&(gf->nodes[i]->src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&(gf->nodes[i]->nb[1]) length:sizeof(uint64_t) atIndex:5];
[encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
const int64_t n = ggml_nelements(gf->nodes[i]->src1);
const int64_t n = ggml_nelements(src1);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
@@ -610,7 +664,7 @@ int llama_mtl_eval(
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
const int64_t nrows = ggml_nrows(src0);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
@@ -620,12 +674,12 @@ int llama_mtl_eval(
encoder = [command_buffer computeCommandEncoder];
}
const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
const int mode = ((int32_t *) gf->nodes[i]->src1->data)[2];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
//fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
//fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
//fprintf(stderr, "rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
//mtl_printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
//mtl_printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
//mtl_printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -660,11 +714,11 @@ int llama_mtl_eval(
const int nth = 32;
//fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
//fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
//fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
//fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
//fprintf(stderr, "cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
//mtl_printf("cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
//mtl_printf("cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
//mtl_printf("cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
//mtl_printf("cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
//mtl_printf("cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
switch (src0t) {
case GGML_TYPE_F32:
@@ -700,7 +754,7 @@ int llama_mtl_eval(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false);
return -1;
}
@@ -708,7 +762,7 @@ int llama_mtl_eval(
// extract results from the GPU
{
fprintf(stderr, "%s: extract results from the GPU\n", __func__);
mtl_printf("%s: extract results from the GPU\n", __func__);
if (encoder != nil) {
[encoder endEncoding];
@@ -730,18 +784,19 @@ int llama_mtl_eval(
{
const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
mtl_printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
}
// TODO
const float * logits = ctx->out.contents;
ctx->logits = ctx->out.contents;
const float * logits = ctx->logits;
#if 1
printf("logits: ");
mtl_printf("logits: ");
for (int i = 0; i < 100; i++) {
printf("%8.4f ", logits[i]);
mtl_printf("%8.4f ", logits[i]);
}
printf("\n");
mtl_printf("\n");
double sum = 0.0;
int imax = 0;
double vmax = -INFINITY;
@@ -752,7 +807,7 @@ int llama_mtl_eval(
imax = i;
}
}
printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
mtl_printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
#endif
//{
@@ -801,3 +856,7 @@ int llama_mtl_eval(
return 0;
}
float * llama_mtl_get_logits(struct ggml_mtl_context * ctx) {
return ctx->logits;
}

llama.cpp

@@ -9,6 +9,9 @@
#include "llama-util.h"
#include "llama.h"
// METAL
#include "examples/mtl/mtl.h"
#include "ggml.h"
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
@@ -238,6 +241,10 @@ struct llama_context {
llama_ctx_buffer buf_compute;
llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
// METAL
ggml_mtl_context * mtl_ctx = NULL;
ggml_cgraph mtl_gf;
int buf_last = 0;
size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
@@ -836,6 +843,7 @@ struct llama_context_params llama_context_default_params() {
/*.use_mmap =*/ true,
/*.use_mlock =*/ false,
/*.embedding =*/ false,
/*.cgraph =*/ false,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
};
@@ -1270,8 +1278,14 @@ static bool llama_eval_internal(
//struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
struct ggml_tensor * Qpre = ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N);
struct ggml_tensor * Kpre = ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N);
ggml_set_name(Qpre, "Qpre");
ggml_set_name(Kpre, "Kpre");
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, Qpre, n_past, n_rot, 0);
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, Kpre, n_past, n_rot, 0);
ggml_set_name(Qcur, "Qcur");
ggml_set_name(Kcur, "Kcur");
@@ -1279,22 +1293,19 @@
{
// compute the transposed [N, n_embd] V matrix
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
//struct ggml_tensor * t = ggml_cpy(ctx0, Vcur, v);
//// TODO: TMP !!!!
//if (il == 0) {
// ggml_set_name(t, "mtl-check");
//}
ggml_set_name(k, "k");
ggml_set_name(v, "v");
// important: storing RoPE-ed version of K in the KV cache!
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
//ggml_build_forward_expand(&gf, t);
}
struct ggml_tensor * Q =
@@ -2391,6 +2402,22 @@ struct llama_context * llama_init_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
// METAL
if (params.cgraph) {
params.vocab_only = true;
// load the compute graph
struct ggml_context * ctx_data = NULL;
struct ggml_context * ctx_eval = NULL;
struct ggml_cgraph gf = ggml_graph_import("llama.ggml", &ctx_data, &ctx_eval);
gf.n_threads = 1;
// this allocates all Metal resources and memory buffers
ctx->mtl_ctx = llama_mtl_init(ctx_data, ctx_eval, NULL, &gf);
ctx->mtl_gf = gf;
}
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type,
params.use_mmap, params.use_mlock, params.vocab_only,
params.progress_callback, params.progress_callback_user_data)) {
@@ -2411,7 +2438,11 @@ struct llama_context * llama_init_from_file(
const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
}
}
// METAL
// TODO: changed the behavior here for vocab_only -- reconsider implications later
{
const auto & hparams = ctx->model.hparams;
// resized during inference
@@ -3046,10 +3077,26 @@ int llama_eval(
int n_tokens,
int n_past,
int n_threads) {
// METAL
if (ctx->mtl_ctx) {
llama_mtl_eval(ctx->mtl_ctx, &ctx->mtl_gf, tokens, n_tokens, n_past);
const float * logits = llama_mtl_get_logits(ctx->mtl_ctx);
// extract logits
{
const int n_vocab = ctx->model.hparams.n_vocab;
auto & logits_out = ctx->logits;
logits_out.resize(n_vocab);
memcpy(logits_out.data(), logits, sizeof(float)*n_vocab);
}
} else {
if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
fprintf(stderr, "%s: failed to eval\n", __func__);
return 1;
}
}
// get a more accurate load time, upon first eval
// TODO: fix this

llama.h

@@ -75,6 +75,7 @@ extern "C" {
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
bool embedding; // embedding mode only
bool cgraph; // try to load computation graph from "llama.ggml" (METAL)
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
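
Putting the pieces together, the new cgraph flag gives an end-to-end Metal path through the public API. A minimal sketch (not part of this commit; eval_on_metal() is hypothetical, "llama.ggml" is assumed to have been produced beforehand with --export, and tokenization/sampling are left out):

#include "llama.h"

static int eval_on_metal(const char * model_path, const llama_token * tokens, int n_tokens) {
    struct llama_context_params lparams = llama_context_default_params();
    lparams.cgraph = true; // load "llama.ggml" and set up the Metal context

    struct llama_context * lctx = llama_init_from_file(model_path, lparams);
    if (lctx == NULL) {
        return 1;
    }

    // with mtl_ctx set, llama_eval() dispatches to llama_mtl_eval() and copies
    // the last token's logits back into the context
    if (llama_eval(lctx, tokens, n_tokens, /*n_past =*/ 0, /*n_threads =*/ 1) != 0) {
        llama_free(lctx);
        return 1;
    }

    const float * logits = llama_get_logits(lctx); // n_vocab values
    (void) logits; // feed into sampling as usual

    llama_free(lctx);
    return 0;
}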