mtl : make it work with main example
Lots of hacks but at least now it generates text
parent 2f4e9d19cc
commit 4df2ef3161
4 changed files with 119 additions and 201 deletions
@@ -24,6 +24,8 @@ int main(int argc, char ** argv) {
    struct ggml_cgraph gf = ggml_graph_import(fname_cgraph, &ctx_data, &ctx_eval);
    gf.n_threads = 1;

    int32_t n_vocab = 0;

    {
        struct ggml_tensor * t_vocab = ggml_graph_get_tensor(&gf, "vocab");
        if (t_vocab == NULL) {
@@ -33,7 +35,6 @@ int main(int argc, char ** argv) {
        const char * ptr = (const char *) t_vocab->data;

        int32_t n_vocab = 0;
        memcpy(&n_vocab, ptr, sizeof(n_vocab)); ptr += sizeof(n_vocab);

        printf("%s: n_vocab = %d\n", __func__, n_vocab);
@@ -49,20 +50,14 @@ int main(int argc, char ** argv) {
        }
    }

    // allocate work context
    static size_t buf_size = gf.work_size; // TODO
    static void * buf = malloc(buf_size);

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ buf,
        /*.no_alloc   =*/ false,
    };

    struct ggml_context * ctx_work = ggml_init(params);

    // this allocates all Metal resources and memory buffers
    auto * ctx_mtl = llama_mtl_init(ctx_data, ctx_eval, ctx_work, &gf);
    auto * ctx_mtl = llama_mtl_init(
        ggml_get_mem_buffer(ctx_data),
        ggml_get_mem_size  (ctx_data),
        ggml_get_mem_buffer(ctx_eval),
        ggml_get_mem_size  (ctx_eval),
        NULL, 0, // cache
        32*n_vocab*sizeof(float));

    // TODO: tmp to match the input used when creating the cgraph
    {
@@ -90,7 +85,6 @@ int main(int argc, char ** argv) {
    llama_mtl_free(ctx_mtl);

    ggml_free(ctx_work);
    ggml_free(ctx_data);
    ggml_free(ctx_eval);
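Note: taken together, the example now drives the Metal path end to end: import the exported graph, read the vocabulary size from the "vocab" leaf, pin the data and eval memory on the GPU, run an eval, and free everything. A minimal sketch of that flow, assuming the interfaces shown in this commit (llama_mtl_init taking raw buffers, llama_mtl_eval, llama_mtl_get_logits); the header name, graph file name, and token value are placeholders, and error handling plus the rest of the vocabulary parsing are elided:

```cpp
// Sketch only: mirrors the structure of the updated example, not a verbatim copy.
#include "ggml.h"
#include "mtl.h"      // assumed header name for the llama_mtl_* declarations
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
    struct ggml_context * ctx_data = NULL;
    struct ggml_context * ctx_eval = NULL;

    // import the exported compute graph (weights land in ctx_data, activations in ctx_eval)
    struct ggml_cgraph gf = ggml_graph_import("llama.ggml", &ctx_data, &ctx_eval);
    gf.n_threads = 1;

    // read the vocabulary size from the "vocab" leaf attached to the graph
    int32_t n_vocab = 0;
    {
        struct ggml_tensor * t_vocab = ggml_graph_get_tensor(&gf, "vocab");
        memcpy(&n_vocab, t_vocab->data, sizeof(n_vocab));
    }

    // pin the data and eval memory on the GPU; no KV cache buffer in this example
    struct ggml_mtl_context * ctx_mtl = llama_mtl_init(
        ggml_get_mem_buffer(ctx_data), ggml_get_mem_size(ctx_data),
        ggml_get_mem_buffer(ctx_eval), ggml_get_mem_size(ctx_eval),
        NULL, 0,                       // cache
        32*n_vocab*sizeof(float));     // output (logits) buffer

    // evaluate a single token at position 0 and read back the logits
    std::vector<int> tokens = { 1 };
    llama_mtl_eval(ctx_mtl, &gf, tokens.data(), (int) tokens.size(), /*n_past=*/0);
    const float * logits = llama_mtl_get_logits(ctx_mtl);
    printf("logits[0] = %f\n", logits[0]);

    llama_mtl_free(ctx_mtl);
    ggml_free(ctx_data);
    ggml_free(ctx_eval);
    return 0;
}
```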
@@ -1,5 +1,7 @@
#pragma once

#include <stddef.h>

struct ggml_context;
struct ggml_cgraph;
@@ -10,10 +12,13 @@ extern "C" {
struct ggml_mtl_context;

struct ggml_mtl_context * llama_mtl_init(
        struct ggml_context * ctx_data,
        struct ggml_context * ctx_eval,
        struct ggml_context * ctx_work,
        struct ggml_cgraph  * gf);
        void * data_buf,
        size_t data_size,
        void * eval_buf,
        size_t eval_size,
        void * cach_buf,
        size_t cach_size,
        size_t outp_size);

void llama_mtl_free(struct ggml_mtl_context * ctx);
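Note: reassembled from the hunk above, the declaration now takes raw memory buffers and sizes instead of ggml contexts and a graph; the ggml_cgraph is supplied per call to llama_mtl_eval instead. This is only the new prototype spliced back together for readability, with comments describing each argument as it is used elsewhere in this commit:

```cpp
// New public interface after this commit: the Metal context is built from
// plain (buffer, size) pairs rather than ggml contexts.
struct ggml_mtl_context * llama_mtl_init(
    void * data_buf, size_t data_size,   // model data (weights)
    void * eval_buf, size_t eval_size,   // intermediate results of the evaluation
    void * cach_buf, size_t cach_size,   // KV cache memory; may be NULL / 0
    size_t outp_size);                   // size of the output (logits) buffer

void llama_mtl_free(struct ggml_mtl_context * ctx);
```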
@@ -11,11 +11,16 @@
#else
#define mtl_printf(...) fprintf(stderr, __VA_ARGS__)
#endif
//#define mtl_printf(...)

struct ggml_mtl_context {
    struct ggml_context * ctx_data;
    struct ggml_context * ctx_eval;
    struct ggml_context * ctx_work;
    void * data_buf;
    size_t data_size;
    void * eval_buf;
    size_t eval_size;
    void * cach_buf;
    size_t cach_size;
    size_t outp_size;

    float * logits;
@@ -25,6 +30,7 @@ struct ggml_mtl_context {
    id<MTLBuffer> buffer_data;
    id<MTLBuffer> buffer_eval;
    id<MTLBuffer> buffer_cach;

    id<MTLBuffer> out;
@@ -82,17 +88,23 @@ struct ggml_mtl_context {
NSString * const msl_library_llama = @"see mtl.metal";

struct ggml_mtl_context * llama_mtl_init(
        struct ggml_context * ctx_data,
        struct ggml_context * ctx_eval,
        struct ggml_context * ctx_work,
        struct ggml_cgraph  * gf) {
        void * data_buf,
        size_t data_size,
        void * eval_buf,
        size_t eval_size,
        void * cach_buf,
        size_t cach_size,
        size_t outp_size) {
    fprintf(stderr, "%s: allocating\n", __func__);

    struct ggml_mtl_context * ctx = malloc(sizeof(struct ggml_mtl_context));

    ctx->ctx_data = ctx_data;
    ctx->ctx_eval = ctx_eval;
    ctx->ctx_work = ctx_work;
    ctx->data_buf  = data_buf;
    ctx->data_size = data_size;
    ctx->eval_buf  = eval_buf;
    ctx->eval_size = eval_size;
    ctx->cach_buf  = cach_buf;
    ctx->cach_size = cach_size;

    ctx->device = MTLCreateSystemDefaultDevice();
    ctx->queue  = [ctx->device newCommandQueue];
@@ -208,9 +220,10 @@ struct ggml_mtl_context * llama_mtl_init(
    // TODO: how to use MTLStorageModeManaged?
    // TODO: see if we can avoid this copy somehow
    {
        const void * mem_buffer = ggml_get_mem_buffer(ctx_data);
        const size_t mem_size = ggml_get_mem_size(ctx_data);
        void * mem_buffer = data_buf;
        const size_t mem_size = data_size;

        //ctx->buffer_data = [ctx->device newBufferWithBytesNoCopy:mem_buffer length:mem_size options:MTLResourceStorageModeShared deallocator:nil];
        ctx->buffer_data = [ctx->device newBufferWithBytes:mem_buffer length:mem_size options:MTLResourceStorageModeShared];

        fprintf(stderr, "%s: allocated data buffer, size = %8.2f MB\n", __func__, mem_size / 1024.0 / 1024.0);
@@ -219,16 +232,26 @@ struct ggml_mtl_context * llama_mtl_init(
    // pin ctx_eval memory to GPU
    // this buffer will be used for the intermediate results of the evaluation
    {
        const size_t mem_size = ggml_get_mem_size(ctx_eval);
        const void * mem_buffer = eval_buf;
        const size_t mem_size = eval_size;

        ctx->buffer_eval = [ctx->device newBufferWithLength:mem_size options:MTLResourceStorageModePrivate];
        ctx->buffer_eval = [ctx->device newBufferWithBytes:mem_buffer length:mem_size options:MTLResourceStorageModeShared];

        fprintf(stderr, "%s: allocated eval buffer, size = %8.2f MB\n", __func__, mem_size / 1024.0 / 1024.0);
    }

    if (cach_buf) {
        const void * mem_buffer = cach_buf;
        const size_t mem_size = cach_size;

        ctx->buffer_cach = [ctx->device newBufferWithBytes:mem_buffer length:mem_size options:MTLResourceStorageModeShared];

        fprintf(stderr, "%s: allocated cach buffer, size = %8.2f MB\n", __func__, mem_size / 1024.0 / 1024.0);
    }

    // allocate buffer for result extraction
    {
        const size_t mem_size = ggml_nbytes(gf->nodes[gf->n_nodes - 1]);
        const size_t mem_size = outp_size;

        ctx->out = [ctx->device newBufferWithLength:mem_size options:MTLResourceStorageModeShared];
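Note: all three host regions (data, eval, cache) are now mirrored into shared MTLBuffers via newBufferWithBytes:, which copies the contents; the eval buffer in particular changes from a private, length-only allocation to a shared copy of the host memory. The commented-out newBufferWithBytesNoCopy: call and the "avoid this copy" TODO point at the eventual fix: wrap the existing allocation without copying, which requires (at least) page-aligned memory. A hypothetical helper, not part of this patch, showing how the ggml buffers could be allocated so that no-copy wrapping becomes possible:

```cpp
// Hypothetical helper (assumption, not in the commit): allocate a buffer that
// is page-aligned and padded to a whole number of pages, which is what
// newBufferWithBytesNoCopy: expects, so the data/eval/cache regions could be
// handed to Metal without the extra copy.
#include <stdlib.h>
#include <unistd.h>

static void * alloc_page_aligned(size_t * size /* in/out: rounded up */) {
    const size_t page = (size_t) sysconf(_SC_PAGESIZE);
    *size = ((*size + page - 1) / page) * page;   // round up to a whole page
    void * ptr = NULL;
    if (posix_memalign(&ptr, page, *size) != 0) {
        return NULL;                              // allocation failed
    }
    return ptr;
}
```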
@@ -246,30 +269,48 @@ void llama_mtl_free(struct ggml_mtl_context * ctx) {

// get data / eval buffer + offset
id<MTLBuffer> llama_mtl_get_buffer(struct ggml_mtl_context * ctx, struct ggml_tensor * t, size_t * offs) {
    const int64_t offs_data = (int64_t) t->data - (int64_t) ggml_get_mem_buffer(ctx->ctx_data);
    const int64_t offs_eval = (int64_t) t->data - (int64_t) ggml_get_mem_buffer(ctx->ctx_eval);

    const bool is_data = (offs_eval < 0) || (offs_data >= 0 && offs_data < offs_eval);
    const int64_t offs_data = (int64_t) t->data - (int64_t) ctx->data_buf;
    const int64_t offs_eval = (int64_t) t->data - (int64_t) ctx->eval_buf;
    const int64_t offs_cach = (int64_t) t->data - (int64_t) ctx->cach_buf;

    //const size_t t_size = ggml_nbytes(t);
    const size_t t_offs = is_data ? offs_data : offs_eval;

    id<MTLBuffer> result;
    size_t t_offs = 0;

    if (is_data) {
        //fprintf(stderr, "%s: data tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
    if ( offs_data > 0 &&
         (offs_eval < 0 || (offs_data < offs_eval)) &&
         (offs_cach < 0 || (offs_data < offs_cach))
       ) {
        result = ctx->buffer_data;
    } else {
        //fprintf(stderr, "%s: eval tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
        result = ctx->buffer_eval;
        t_offs = offs_data;
        //fprintf(stderr, "%s: data tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
    }

    if (result == nil) {
    if ( offs_eval > 0 &&
         (offs_data < 0 || (offs_eval < offs_data)) &&
         (offs_cach < 0 || (offs_eval < offs_cach))
       ) {
        result = ctx->buffer_eval;
        t_offs = offs_eval;
        //fprintf(stderr, "%s: data tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
    }

    if ( offs_cach > 0 &&
         (offs_data < 0 || (offs_cach < offs_data)) &&
         (offs_eval < 0 || (offs_cach < offs_eval))
       ) {
        result = ctx->buffer_cach;
        t_offs = offs_cach;
        //fprintf(stderr, "%s: data tensor '%16s', offs = %8ld, size = %8ld\n", __func__, t->name, t_offs, t_size);
    }

    if (result == nil || (t_offs > ctx->data_size && t_offs > ctx->eval_size && t_offs > ctx->cach_size)) {
        fprintf(stderr, "%s: error: buffer is nil\n", __func__);
        GGML_ASSERT(false);
    }

    if (offs != nil) {
    if (offs != 0) {
        *offs = t_offs;
    }
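Note: the rewritten llama_mtl_get_buffer maps a tensor's data pointer to one of the three pinned regions by picking the region whose base is closest below the pointer, then returns the byte offset into that region. A standalone C++ illustration of that selection logic (helper name and struct are hypothetical, for illustration only; the sketch accepts an offset of 0 for a tensor at the very start of a buffer, whereas the patch compares with > 0, and it simplifies the size sanity check):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

struct region { const char * base; size_t size; };

// Return the index of the region containing p (0 = data, 1 = eval, 2 = cache)
// and store the byte offset into that region in *offs. The containing region
// is the one with the smallest non-negative (p - base) distance.
static int select_region(const void * p, const region regions[3], size_t * offs) {
    int     best      = -1;
    int64_t best_offs = -1;

    for (int i = 0; i < 3; ++i) {
        if (regions[i].base == nullptr) continue;            // e.g. no cache buffer
        const int64_t d = (const char *) p - regions[i].base;
        if (d >= 0 && (best == -1 || d < best_offs)) {
            best      = i;
            best_offs = d;
        }
    }

    // the pointer must land inside the chosen region
    assert(best != -1 && (size_t) best_offs < regions[best].size);

    if (offs) {
        *offs = (size_t) best_offs;
    }
    return best;
}
```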
@@ -284,49 +325,9 @@ int llama_mtl_eval(
        int n_past) {
    mtl_printf("%s: evaluating, n_tokens = %d, n_past = %d\n", __func__, n_tokens, n_past);

    // adjust dynamic shapes
    // TODO: wrong ...
    //{
    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "embd");
    //    t->ne[0] = n_tokens;
    //}
    //{
    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Qpre");
    //    t->src0->ne[2] = n_tokens;
    //}
    //{
    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Kpre");
    //    t->src0->ne[2] = n_tokens;
    //}
    //{
    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Vcur");
    //    t->ne[0] = n_tokens;
    //}
    //{
    //    struct ggml_tensor * k = ggml_graph_get_tensor(gf, "k");
    //    struct ggml_tensor * v = ggml_graph_get_tensor(gf, "v");
    //    k->ne[0] = n_tokens*v->ne[1];
    //    v->ne[0] = n_tokens;
    //}
    //{
    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Q");
    //    t->ne[1] = n_tokens;
    //}
    //{
    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "K");
    //    t->ne[1] = n_past + n_tokens;
    //}
    //{
    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "KQV_merged_contiguous");
    //    t->src1->ne[1] = n_tokens;
    //}

    struct ggml_tensor * input = ggml_graph_get_tensor(gf, "embd");
    memcpy(input->data, tokens, n_tokens * sizeof(int));

    id<MTLCommandBuffer> command_buffer = [ctx->queue commandBuffer];
    id<MTLComputeCommandEncoder> encoder = nil;

    size_t offs_src0 = 0;
    size_t offs_src1 = 0;
    size_t offs_dst  = 0;
@@ -340,6 +341,9 @@ int llama_mtl_eval(
        memcpy((char *) id_dst.contents + offs_src0, embd->data, ggml_nbytes(embd));
    }

    id<MTLCommandBuffer> command_buffer = [ctx->queue commandBuffer];
    id<MTLComputeCommandEncoder> encoder = nil;

    for (int i = 0; i < gf->n_nodes; ++i) {
        //mtl_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
@@ -791,6 +795,9 @@ int llama_mtl_eval(

    const float * logits = ctx->logits;

    struct ggml_tensor * t = gf->nodes[gf->n_nodes - 1];
    memcpy(t->data, logits, ggml_nbytes(t));

#if 1
    mtl_printf("logits: ");
    for (int i = 0; i < 100; i++) {
llama.cpp (136 changed lines)
@@ -243,7 +243,6 @@ struct llama_context {
    // METAL
    ggml_mtl_context * mtl_ctx = NULL;
    ggml_cgraph mtl_gf;

    int buf_last = 0;
    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
@@ -1262,7 +1261,7 @@ static bool llama_eval_internal(
    struct ggml_tensor * cur;

    lctx.use_buf(ctx0, 0);
    //lctx.use_buf(ctx0, 0);

    // norm
    {
@@ -1378,7 +1377,7 @@ static bool llama_eval_internal(
                cur);
    }

    lctx.use_buf(ctx0, 1);
    //lctx.use_buf(ctx0, 1);

    struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
@@ -1416,7 +1415,7 @@ static bool llama_eval_internal(
        inpL = cur;
    }

    lctx.use_buf(ctx0, 0);
    //lctx.use_buf(ctx0, 0);

    // used at the end to optionally extract the embeddings
    struct ggml_tensor * embeddings = NULL;
@@ -1435,85 +1434,20 @@ static bool llama_eval_internal(
    // lm_head
    inpL = ggml_mul_mat(ctx0, model.output, inpL);

    lctx.use_buf(ctx0, -1);
    //lctx.use_buf(ctx0, -1);

    // logits -> probs
    //inpL = ggml_soft_max_inplace(ctx0, inpL);

    // run the computation
    ggml_build_forward_expand(&gf, inpL);

    // METAL
    if (lctx.mtl_ctx) {
        llama_mtl_eval(lctx.mtl_ctx, &gf, tokens, n_tokens, n_past);
    } else {
        ggml_graph_compute (ctx0, &gf);

    // TODO: not needed anymore, keeping for a bit
    //// lets export a smaller graph to get things rolling -- baby steps first
    //{
    //    struct ggml_tensor * t = ggml_get_tensor(ctx0, "mtl-check");
    //    if (!t) {
    //        fprintf(stderr, "%s: failed to find tensor 'mtl-check'\n", __func__);
    //        exit(1);
    //    }
    //    ggml_build_forward_expand(&gf, t);
    //}

    // print
    //{
    //    auto print_t_f32 = [&](struct ggml_tensor * t) {
    //        float * data = (float *)t->data;
    //        printf("data: ");
    //        for (int i = 0; i < (int) t->ne[0]; i++) {
    //            printf("%f ", data[i]);
    //        }
    //        printf("\n");
    //        double sum = 0.0;
    //        for (int i = 0; i < ggml_nelements(t); i++) {
    //            double cur = data[i];
    //            if (isinf(cur)) continue;
    //            sum += data[i];
    //        }
    //        printf("sum: %f\n", sum);
    //    };
    //    auto print_t_f16 = [&](struct ggml_tensor * t) {
    //        ggml_fp16_t * data = (ggml_fp16_t *)t->data;
    //        printf("data: ");
    //        for (int i = 0; i < (int) t->ne[0]; i++) {
    //            printf("%f ", ggml_fp16_to_fp32(data[i]));
    //        }
    //        printf("\n");
    //        double sum = 0.0;
    //        printf("nb: %lld %lld %lld %lld\n", t->nb[0], t->nb[1], t->nb[2], t->nb[3]);
    //        for (int64_t i3 = 0; i3 < t->ne[3]; ++i3) {
    //            for (int64_t i2 = 0; i2 < t->ne[2]; ++i2) {
    //                for (int64_t i1 = 0; i1 < t->ne[1]; ++i1) {
    //                    for (int64_t i0 = 0; i0 < t->ne[0]; ++i0) {
    //                        const size_t offs = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0*t->nb[0];
    //                        const ggml_fp16_t cur = *((ggml_fp16_t *)((char *) data + offs));
    //                        const float curf = ggml_fp16_to_fp32(cur);
    //                        if (isinf(curf)) continue;
    //                        sum += curf;
    //                    }
    //                }
    //            }
    //        }
    //        printf("sum: %f\n", sum);
    //    };

    //    ggml_graph_compute(ctx0, &gf);

    //    {
    //        auto * t = ggml_get_tensor(ctx0, "mtl-check");
    //        switch (t->type) {
    //            case GGML_TYPE_F32:
    //                print_t_f32(t);
    //                break;
    //            case GGML_TYPE_F16:
    //                print_t_f16(t);
    //                break;
    //            default:
    //                fprintf(stderr, "%s: unsupported type\n", __func__);
    //                exit(1);
    //        }
    //    }
    //}
    }

    if (cgraph_fname) {
        // TODO: tmp add the vocabulary as a leaf to the computation graph, until better approach is found
@@ -2402,22 +2336,6 @@ struct llama_context * llama_init_from_file(
    ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;

    // METAL
    if (params.cgraph) {
        params.vocab_only = true;

        // load the compute graph
        struct ggml_context * ctx_data = NULL;
        struct ggml_context * ctx_eval = NULL;

        struct ggml_cgraph gf = ggml_graph_import("llama.ggml", &ctx_data, &ctx_eval);
        gf.n_threads = 1;

        // this allocates all Metal resources and memory buffers
        ctx->mtl_ctx = llama_mtl_init(ctx_data, ctx_eval, NULL, &gf);
        ctx->mtl_gf = gf;
    }

    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type,
                          params.use_mmap, params.use_mlock, params.vocab_only,
                          params.progress_callback, params.progress_callback_user_data)) {
@@ -2438,11 +2356,7 @@ struct llama_context * llama_init_from_file(
        const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
        fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
        }
    }

    // METAL
    // TODO: changed the behavior here for vocab_only -- reconsider implications later
    {
        const auto & hparams = ctx->model.hparams;

        // resized during inference
@@ -2462,6 +2376,20 @@ struct llama_context * llama_init_from_file(
            ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
        }

        // METAL
        if (params.cgraph) {
            // this allocates all Metal resources and memory buffers
            //ctx->mtl_ctx = llama_mtl_init(ctx_data, ctx_eval, &gf);
            ctx->mtl_ctx = llama_mtl_init(
                ggml_get_mem_buffer(ctx->model.ctx),
                ggml_get_mem_size  (ctx->model.ctx),
                ctx->buf_compute.addr,
                ctx->buf_compute.size,
                ctx->model.kv_self.buf.addr,
                ctx->model.kv_self.buf.size,
                32*ctx->model.hparams.n_vocab*sizeof(float));
        }

    return ctx;
}
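Note: the 32*n_vocab*sizeof(float) argument sizes the Metal-side output buffer that receives the logits. The factor of 32 is not explained in the commit; presumably it reserves room for up to 32 tokens' worth of logits per llama_mtl_eval call. As a rough order of magnitude, with an example vocabulary size (not taken from the commit):

```cpp
// Back-of-the-envelope sizing for the logits output buffer.
const int    n_vocab   = 32000;                        // example vocabulary size
const size_t outp_size = 32ull*n_vocab*sizeof(float);  // 4,096,000 bytes, ~3.9 MB
```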
@@ -3077,26 +3005,10 @@ int llama_eval(
        int n_tokens,
        int n_past,
        int n_threads) {
    // METAL
    if (ctx->mtl_ctx) {
        llama_mtl_eval(ctx->mtl_ctx, &ctx->mtl_gf, tokens, n_tokens, n_past);

        const float * logits = llama_mtl_get_logits(ctx->mtl_ctx);

        // extract logits
        {
            const int n_vocab = ctx->model.hparams.n_vocab;
            auto & logits_out = ctx->logits;

            logits_out.resize(n_vocab);
            memcpy(logits_out.data(), logits, sizeof(float)*n_vocab);
        }
    } else {
    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
        fprintf(stderr, "%s: failed to eval\n", __func__);
        return 1;
    }
    }

    // get a more accurate load time, upon first eval
    // TODO: fix this
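Note: from the application's point of view nothing changes: llama_eval routes through llama_mtl_eval whenever a Metal context was created and copies one token's worth of logits into ctx->logits, so an ordinary decoding loop keeps working. A minimal sketch, assuming the public llama.h API of this period (llama_eval, llama_get_logits, llama_n_vocab) and greedy sampling chosen purely for brevity; the Metal path at this stage is still tied to the shape of the exported graph (see the TODOs in llama_mtl_eval), so single-token evals are the safe case:

```cpp
#include "llama.h"
#include <vector>

// Generate up to n_gen tokens greedily on top of an already-created context.
// Whether the Metal or the CPU path runs is decided inside llama_eval.
static void greedy_generate(llama_context * ctx, std::vector<llama_token> & tokens, int n_gen) {
    int n_past = 0;

    for (int i = 0; i < n_gen; ++i) {
        // evaluate everything that has not been fed to the model yet
        const int n_eval = (int) tokens.size() - n_past;
        if (llama_eval(ctx, tokens.data() + n_past, n_eval, n_past, /*n_threads=*/1) != 0) {
            break; // eval failed
        }
        n_past += n_eval;

        // pick the most likely next token from the returned logits
        const float * logits  = llama_get_logits(ctx);
        const int     n_vocab = llama_n_vocab(ctx);

        llama_token best = 0;
        for (llama_token t = 1; t < n_vocab; ++t) {
            if (logits[t] > logits[best]) {
                best = t;
            }
        }
        tokens.push_back(best);
    }
}
```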