mtl : plug Metal inference into llama.cpp (very quick-n-dirty)
commit 2f4e9d19cc
parent 640a889632
7 changed files with 186 additions and 61 deletions
CMakeLists.txt

@@ -384,11 +384,22 @@ endif()
 add_library(llama
             llama.cpp
             llama.h
-            llama-util.h)
+            llama-util.h
+            examples/mtl/mtl.h # TODO: METAL TMP
+            examples/mtl/mtl.m # TODO: METAL TMP
+            )

 target_include_directories(llama PUBLIC .)
 target_compile_features(llama PUBLIC cxx_std_11) # don't bump
-target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS})
+target_link_libraries(llama PRIVATE
+    ggml
+    ${LLAMA_EXTRA_LIBS}
+    ${FOUNDATION_LIBRARY}          # TODO: METAL TMP
+    ${METAL_FRAMEWORK}             # TODO: METAL TMP
+    ${METALKIT_FRAMEWORK}          # TODO: METAL TMP
+    ${METALPERFORMANCE_FRAMEWORK}  # TODO: METAL TMP
+    )
+target_compile_definitions(llama PRIVATE LLAMA_MTL_NDEBUG) # TODO: METAL TMP

 if (BUILD_SHARED_LIBS)
     set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)
examples/common.cpp

@@ -301,6 +301,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.mem_test = true;
         } else if (arg == "--export") {
             params.export_cgraph = true;
+        } else if (arg == "--import") {
+            params.import_cgraph = true;
         } else if (arg == "--verbose-prompt") {
             params.verbose_prompt = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
@@ -441,6 +443,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
 #endif
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --export              export the computation graph to 'llama.ggml'\n");
+    fprintf(stderr, "  --import              import a computation graph from 'llama.ggml'\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
@@ -490,6 +493,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
     lparams.use_mlock   = params.use_mlock;
     lparams.logits_all  = params.perplexity;
     lparams.embedding   = params.embedding;
+    lparams.cgraph      = params.import_cgraph;

     llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);

examples/common.h

@@ -72,6 +72,7 @@ struct gpt_params {
     bool use_mlock         = false; // use mlock to keep model in memory
     bool mem_test          = false; // compute maximum memory usage
     bool export_cgraph     = false; // export the computation graph
+    bool import_cgraph     = false; // import a computation graph
     bool verbose_prompt    = false; // print prompt tokens before generation
 };

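The two flags are meant to be used as a pair: --export (already present) writes the computation graph to 'llama.ggml', and the new --import loads it back and drives inference through Metal. Below is a minimal sketch of the programmatic equivalent, not part of the commit; the model path is a hypothetical placeholder and the wiring follows the hunks above (params.import_cgraph -> lparams.cgraph):

    #include "common.h"  // gpt_params, llama_init_from_gpt_params
    #include "llama.h"

    int main() {
        gpt_params params;
        params.model         = "models/7B/ggml-model-q4_0.bin"; // hypothetical path
        params.import_cgraph = true;                            // same effect as passing --import

        // llama_init_from_gpt_params() copies this into lparams.cgraph, so
        // llama_init_from_file() will try to load "llama.ggml" and set up the Metal path
        llama_context * lctx = llama_init_from_gpt_params(params);

        return lctx == nullptr ? 1 : 0;
    }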
examples/mtl/mtl.h

@@ -25,6 +25,8 @@ int llama_mtl_eval(
         int   n_tokens,
         int   n_past);

+float * llama_mtl_get_logits(struct ggml_mtl_context * ctx);
+
 #ifdef __cplusplus
 }
 #endif
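Not part of the diff: a rough sketch of how this small API is driven end to end, pieced together from the llama.cpp changes further down (llama_init_from_file() and llama_eval()). The graph file name comes from the --export path, and the token buffer here is a made-up single BOS token:

    #include "examples/mtl/mtl.h"
    #include "ggml.h"

    int main() {
        struct ggml_context * ctx_data = NULL;
        struct ggml_context * ctx_eval = NULL;

        // graph previously written by --export
        struct ggml_cgraph gf = ggml_graph_import("llama.ggml", &ctx_data, &ctx_eval);
        gf.n_threads = 1;

        // allocates all Metal resources and memory buffers
        struct ggml_mtl_context * mtl_ctx = llama_mtl_init(ctx_data, ctx_eval, NULL, &gf);

        const int tokens[] = { 1 }; // hypothetical input (BOS)
        llama_mtl_eval(mtl_ctx, &gf, tokens, 1, 0);

        const float * logits = llama_mtl_get_logits(mtl_ctx);
        return logits == NULL ? 1 : 0;
    }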
examples/mtl/mtl.m

@@ -6,11 +6,19 @@
 #import <Metal/Metal.h>
 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>

+#ifdef LLAMA_MTL_NDEBUG
+#define mtl_printf(...)
+#else
+#define mtl_printf(...) fprintf(stderr, __VA_ARGS__)
+#endif
+
 struct ggml_mtl_context {
     struct ggml_context * ctx_data;
     struct ggml_context * ctx_eval;
     struct ggml_context * ctx_work;

+    float * logits;
+
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
     id<MTLLibrary>      library;
@@ -274,7 +282,44 @@ int llama_mtl_eval(
         const int * tokens,
         int   n_tokens,
         int   n_past) {
-    fprintf(stderr, "%s: evaluating, n_tokens = %d, n_past = %d\n", __func__, n_tokens, n_past);
+    mtl_printf("%s: evaluating, n_tokens = %d, n_past = %d\n", __func__, n_tokens, n_past);
+
+    // adjust dynamic shapes
+    // TODO: wrong ...
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "embd");
+    //    t->ne[0] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Qpre");
+    //    t->src0->ne[2] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Kpre");
+    //    t->src0->ne[2] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Vcur");
+    //    t->ne[0] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * k = ggml_graph_get_tensor(gf, "k");
+    //    struct ggml_tensor * v = ggml_graph_get_tensor(gf, "v");
+    //    k->ne[0] = n_tokens*v->ne[1];
+    //    v->ne[0] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Q");
+    //    t->ne[1] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "K");
+    //    t->ne[1] = n_past + n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "KQV_merged_contiguous");
+    //    t->src1->ne[1] = n_tokens;
+    //}

     struct ggml_tensor * input = ggml_graph_get_tensor(gf, "embd");
     memcpy(input->data, tokens, n_tokens * sizeof(int));
@@ -296,7 +341,7 @@ int llama_mtl_eval(
     }

     for (int i = 0; i < gf->n_nodes; ++i) {
-        //fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+        //mtl_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));

         struct ggml_tensor * src0 = gf->nodes[i]->src0;
         struct ggml_tensor * src1 = gf->nodes[i]->src1;
@@ -340,7 +385,21 @@ int llama_mtl_eval(
         id<MTLBuffer> id_src1 = src1 ? llama_mtl_get_buffer(ctx, src1, &offs_src1) : nil;
         id<MTLBuffer> id_dst  = dst  ? llama_mtl_get_buffer(ctx, dst,  &offs_dst)  : nil;

-        switch (gf->nodes[i]->op) {
+        //mtl_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+        //if (src0) {
+        //    mtl_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+        //            ggml_is_contiguous(src0), src0->name);
+        //}
+        //if (src1) {
+        //    mtl_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+        //            ggml_is_contiguous(src1), src1->name);
+        //}
+        //if (dst) {
+        //    mtl_printf("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
+        //            dst->name);
+        //}
+
+        switch (dst->op) {
             case GGML_OP_RESHAPE:
             case GGML_OP_VIEW:
             case GGML_OP_TRANSPOSE:
@@ -359,7 +418,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];

-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);

                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -369,7 +428,7 @@ int llama_mtl_eval(
                         encoder = [command_buffer computeCommandEncoder];
                     }

-                    if (ggml_nelements(gf->nodes[i]->src1) == ne10) {
+                    if (ggml_nelements(src1) == ne10) {
                         // src1 is a row
                         [encoder setComputePipelineState:ctx->pipeline_mul_row];
                     } else {
@@ -380,7 +439,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                     [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];

-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);

                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -390,14 +449,14 @@ int llama_mtl_eval(
                         encoder = [command_buffer computeCommandEncoder];
                     }

-                    const float scale = *(const float *) gf->nodes[i]->src1->data;
+                    const float scale = *(const float *) src1->data;

                     [encoder setComputePipelineState:ctx->pipeline_scale];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                     [encoder setBytes:&scale length:sizeof(scale) atIndex:2];

-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);

                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -411,7 +470,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];

-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);

                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -425,7 +484,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];

-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);

                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -464,16 +523,11 @@ int llama_mtl_eval(
                 } break;
             case GGML_OP_MUL_MAT:
                 {
-                    //fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src0t), ne00, ne01, ne02, ggml_is_contiguous(gf->nodes[i]->src0));
-                    //fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src1t), ne10, ne11, ne12, ggml_is_contiguous(gf->nodes[i]->src1));
-                    //fprintf(stderr, "mul_mat: dst  - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
-                    //fprintf(stderr, "mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
-
                     GGML_ASSERT(ne00 == ne10);
                     GGML_ASSERT(ne02 == ne12);

-                    if (ggml_is_contiguous(gf->nodes[i]->src0) &&
-                        ggml_is_contiguous(gf->nodes[i]->src1) &&
+                    if (ggml_is_contiguous(src0) &&
+                        ggml_is_contiguous(src1) &&
                         (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {

                         if (encoder != nil) {
@@ -486,13 +540,13 @@ int llama_mtl_eval(

                         // for F32 x F32 we use MPS
                         MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:gf->nodes[i]->src0->nb[1] dataType:src0dt];
+                            matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];

                         MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:gf->nodes[i]->src1->nb[1] dataType:src1dt];
+                            matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];

                         MPSMatrixDescriptor * desc  = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];
+                            matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];

                         MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
                             initWithDevice:ctx->device transposeLeft:false transposeRight:true
@@ -573,22 +627,22 @@ int llama_mtl_eval(
                         encoder = [command_buffer computeCommandEncoder];
                     }

-                    switch (gf->nodes[i]->src0->type) {
+                    switch (src0->type) {
                         case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                         default: {
                                 // not implemented
-                                fprintf(stderr, "%s: node %3d, op = %8s, type = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op), ggml_type_name(gf->nodes[i]->src0->type));
+                                fprintf(stderr, "%s: node %3d, op = %8s, type = %8s not implemented\n", __func__, i, ggml_op_name(dst->op), ggml_type_name(src0->type));
                             }
                     }

                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBytes:&(gf->nodes[i]->src0->ne[0]) length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&(gf->nodes[i]->src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
-                    [encoder setBytes:&(gf->nodes[i]->nb[1])       length:sizeof(uint64_t) atIndex:5];
+                    [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
+                    [encoder setBytes:&(dst->nb[1])  length:sizeof(uint64_t) atIndex:5];

-                    const int64_t n = ggml_nelements(gf->nodes[i]->src1);
+                    const int64_t n = ggml_nelements(src1);

                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -610,7 +664,7 @@ int llama_mtl_eval(
                     [encoder setBytes:&eps length:sizeof( float) atIndex:4];
                     [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];

-                    const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
+                    const int64_t nrows = ggml_nrows(src0);

                     [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
@@ -620,12 +674,12 @@ int llama_mtl_eval(
                         encoder = [command_buffer computeCommandEncoder];
                     }

-                    const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
-                    const int mode   = ((int32_t *) gf->nodes[i]->src1->data)[2];
+                    const int n_dims = ((int32_t *) src1->data)[1];
+                    const int mode   = ((int32_t *) src1->data)[2];

-                    //fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
-                    //fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
-                    //fprintf(stderr, "rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
+                    //mtl_printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                    //mtl_printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                    //mtl_printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);

                     [encoder setComputePipelineState:ctx->pipeline_rope];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -660,11 +714,11 @@ int llama_mtl_eval(

                     const int nth = 32;

-                    //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
-                    //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
-                    //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
-                    //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
-                    //fprintf(stderr, "cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
+                    //mtl_printf("cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                    //mtl_printf("cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
+                    //mtl_printf("cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                    //mtl_printf("cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
+                    //mtl_printf("cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));

                     switch (src0t) {
                         case GGML_TYPE_F32:
@@ -700,7 +754,7 @@ int llama_mtl_eval(
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
             default:
-                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                 GGML_ASSERT(false);
                 return -1;
         }
@@ -708,7 +762,7 @@ int llama_mtl_eval(

     // extract results from the GPU
     {
-        fprintf(stderr, "%s: extract results from the GPU\n", __func__);
+        mtl_printf("%s: extract results from the GPU\n", __func__);

         if (encoder != nil) {
             [encoder endEncoding];
@@ -730,18 +784,19 @@ int llama_mtl_eval(

     {
         const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
-        printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
+        mtl_printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
     }

     // TODO
-    const float * logits = ctx->out.contents;
+    ctx->logits = ctx->out.contents;
+
+    const float * logits = ctx->logits;

 #if 1
-    printf("logits: ");
+    mtl_printf("logits: ");
     for (int i = 0; i < 100; i++) {
-        printf("%8.4f ", logits[i]);
+        mtl_printf("%8.4f ", logits[i]);
     }
-    printf("\n");
+    mtl_printf("\n");
     double sum = 0.0;
     int imax = 0;
     double vmax = -INFINITY;
@@ -752,7 +807,7 @@ int llama_mtl_eval(
             imax = i;
         }
     }
-    printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
+    mtl_printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
 #endif

     //{
@@ -801,3 +856,7 @@ int llama_mtl_eval(

     return 0;
 }
+
+float * llama_mtl_get_logits(struct ggml_mtl_context * ctx) {
+    return ctx->logits;
+}
llama.cpp

@@ -9,6 +9,9 @@
 #include "llama-util.h"
 #include "llama.h"

+// METAL
+#include "examples/mtl/mtl.h"
+
 #include "ggml.h"
 #ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
@@ -238,6 +241,10 @@ struct llama_context {
     llama_ctx_buffer buf_compute;
     llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];

+    // METAL
+    ggml_mtl_context * mtl_ctx = NULL;
+    ggml_cgraph mtl_gf;
+
     int    buf_last = 0;
     size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };

@@ -836,6 +843,7 @@ struct llama_context_params llama_context_default_params() {
         /*.use_mmap                    =*/ true,
         /*.use_mlock                   =*/ false,
         /*.embedding                   =*/ false,
+        /*.cgraph                      =*/ false,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
     };
@@ -1270,8 +1278,14 @@ static bool llama_eval_internal(
         //struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);

         // compute Q and K and RoPE them
-        struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
-        struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+
+        struct ggml_tensor * Qpre = ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N);
+        struct ggml_tensor * Kpre = ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N);
+        ggml_set_name(Qpre, "Qpre");
+        ggml_set_name(Kpre, "Kpre");
+
+        struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, Qpre, n_past, n_rot, 0);
+        struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, Kpre, n_past, n_rot, 0);
         ggml_set_name(Qcur, "Qcur");
         ggml_set_name(Kcur, "Kcur");

@@ -1279,22 +1293,19 @@ static bool llama_eval_internal(
         {
             // compute the transposed [N, n_embd] V matrix
             struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
+            ggml_set_name(Vcur, "Vcur");

             struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
             struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
                     (   n_ctx)*ggml_element_size(kv_self.v),
                     (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));

-            //struct ggml_tensor * t = ggml_cpy(ctx0, Vcur, v);
-            //// TODO: TMP !!!!
-            //if (il == 0) {
-            //    ggml_set_name(t, "mtl-check");
-            //}
+            ggml_set_name(k, "k");
+            ggml_set_name(v, "v");

             // important: storing RoPE-ed version of K in the KV cache!
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
             ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
-            //ggml_build_forward_expand(&gf, t);
         }

         struct ggml_tensor * Q =
@@ -2391,9 +2402,25 @@ struct llama_context * llama_init_from_file(

     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;

+    // METAL
+    if (params.cgraph) {
+        params.vocab_only = true;
+
+        // load the compute graph
+        struct ggml_context * ctx_data = NULL;
+        struct ggml_context * ctx_eval = NULL;
+
+        struct ggml_cgraph gf = ggml_graph_import("llama.ggml", &ctx_data, &ctx_eval);
+        gf.n_threads = 1;
+
+        // this allocates all Metal resources and memory buffers
+        ctx->mtl_ctx = llama_mtl_init(ctx_data, ctx_eval, NULL, &gf);
+        ctx->mtl_gf = gf;
+    }
+
     if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type,
-                          params.use_mmap, params.use_mlock, params.vocab_only,
-                          params.progress_callback, params.progress_callback_user_data)) {
+                          params.use_mmap, params.use_mlock, params.vocab_only,
+                          params.progress_callback, params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
         return nullptr;
@@ -2411,7 +2438,11 @@ struct llama_context * llama_init_from_file(
         const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
         fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
+    }
+
+    // METAL
+    // TODO: changed the behavior here for vocab_only -- reconsider implications later
     {
         const auto & hparams = ctx->model.hparams;

         // resized during inference
|
|||
int n_tokens,
|
||||
int n_past,
|
||||
int n_threads) {
|
||||
if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
|
||||
fprintf(stderr, "%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
// METAL
|
||||
if (ctx->mtl_ctx) {
|
||||
llama_mtl_eval(ctx->mtl_ctx, &ctx->mtl_gf, tokens, n_tokens, n_past);
|
||||
|
||||
const float * logits = llama_mtl_get_logits(ctx->mtl_ctx);
|
||||
|
||||
// extract logits
|
||||
{
|
||||
const int n_vocab = ctx->model.hparams.n_vocab;
|
||||
auto & logits_out = ctx->logits;
|
||||
|
||||
logits_out.resize(n_vocab);
|
||||
memcpy(logits_out.data(), logits, sizeof(float)*n_vocab);
|
||||
}
|
||||
} else {
|
||||
if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
|
||||
fprintf(stderr, "%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// get a more accurate load time, upon first eval
|
||||
|
|
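Not part of the diff: from an application's point of view the API is unchanged. When the context was created with cgraph = true, llama_eval() above routes through llama_mtl_eval() and copies the Metal logits into ctx->logits, so the usual accessor keeps working. A short sketch:

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    // lctx: a context created with llama_context_params.cgraph = true (see above)
    static bool eval_once(llama_context * lctx) {
        std::vector<llama_token> tokens = { 1 }; // hypothetical prompt (BOS)

        if (llama_eval(lctx, tokens.data(), (int) tokens.size(), /*n_past=*/0, /*n_threads=*/1) != 0) {
            fprintf(stderr, "eval failed\n");
            return false;
        }

        // filled from the Metal output buffer when ctx->mtl_ctx is set
        const float * logits = llama_get_logits(lctx);
        return logits != nullptr;
    }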
llama.h

@@ -75,6 +75,7 @@ extern "C" {
         bool use_mmap;   // use mmap if possible
         bool use_mlock;  // force system to keep model in RAM
         bool embedding;  // embedding mode only
+        bool cgraph;     // try to load computation graph from "llama.ggml" (METAL)

         // called with a progress value between 0 and 1, pass NULL to disable
         llama_progress_callback progress_callback;