update llama.cpp, clip.cpp, export-lora.cpp

Author: slaren
Date:   2024-02-11 14:58:14 +01:00
Parent: b13ef36316
Commit: 7f44f296bc

4 changed files with 171 additions and 154 deletions

export-lora.cpp

@ -337,24 +337,14 @@ static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int
     params.mem_buffer = NULL;
     params.no_alloc = true;
     struct ggml_context * ctx = NULL;
-    struct ggml_allocr * alloc = NULL;
+    struct ggml_gallocr * alloc = NULL;
     struct ggml_cgraph * gf = NULL;

     ctx = ggml_init(params);
-    alloc = ggml_allocr_new_measure(tensor_alignment);
+    alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
     gf = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
-    size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf);
-    ggml_allocr_free(alloc);
-    ggml_free(ctx);
-
-    static std::vector<uint8_t> data_compute;
-    data_compute.resize(alloc_size + tensor_alignment);
-
-    ctx = ggml_init(params);
-    alloc = ggml_allocr_new(data_compute.data(), data_compute.size(), tensor_alignment);
-    gf = build_graph_lora(ctx, tensor, lora_a, lora_b, scaling);
-    ggml_allocr_alloc_graph(alloc, gf);
-    ggml_allocr_free(alloc);
+    ggml_gallocr_alloc_graph(alloc, gf);

     struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);

     static std::vector<uint8_t> data_work;
@ -363,6 +353,7 @@ static bool apply_lora(struct ggml_tensor * tensor, struct lora_data * lora, int
     ggml_graph_compute(gf, &cplan);

+    ggml_gallocr_free(alloc);
     ggml_free(ctx);

     return true;
 }
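Note: taken together, the new export-lora path is a single-pass allocation: one graph allocator is created from the CPU buffer type, the graph is allocated once, computed, and the allocator is freed afterwards. A minimal standalone sketch of that flow, assuming the ggml-alloc API used in this hunk (compute_on_cpu and build_graph are illustrative names, not code from this commit):

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include <cstdint>
#include <vector>

// sketch only: build_graph stands for any function that builds a cgraph inside ctx
static void compute_on_cpu(struct ggml_cgraph * (*build_graph)(struct ggml_context *), int n_threads) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data is placed by the allocator, not inside the context
    };
    struct ggml_context * ctx = ggml_init(params);
    ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());

    struct ggml_cgraph * gf = build_graph(ctx);
    ggml_gallocr_alloc_graph(alloc, gf); // sizes and allocates the whole graph in one call

    struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);
    std::vector<uint8_t> work(cplan.work_size);
    cplan.work_data = work.data();
    ggml_graph_compute(gf, &cplan);

    ggml_gallocr_free(alloc); // free only after the graph has been computed
    ggml_free(ctx);
}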

clip.cpp

@ -367,7 +367,7 @@ struct clip_ctx {
     ggml_backend_buffer_t params_buffer = NULL;
     ggml_backend_buffer_t compute_buffer = NULL;
     ggml_backend_t backend = NULL;
-    ggml_allocr * compute_alloc = NULL;
+    ggml_gallocr_t compute_alloc = NULL;
 };

 static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
@ -405,31 +405,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);

     struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size, image_size, 3, batch_size);
-    ggml_allocr_alloc(ctx->compute_alloc, inp_raw);
-
-    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
-        float * data = (float *)malloc(ggml_nbytes(inp_raw));
-
-        for (size_t i = 0; i < imgs->size; i++) {
-            const int nx = imgs->data[i].nx;
-            const int ny = imgs->data[i].ny;
-            GGML_ASSERT(nx == image_size && ny == image_size);
-
-            const int n = nx * ny;
-
-            for (int b = 0; b < batch_size; b++) {
-                for (int k = 0; k < 3; k++) {
-                    for (int y = 0; y < ny; y++) {
-                        for (int x = 0; x < nx; x++) {
-                            data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
-                        }
-                    }
-                }
-            }
-        }
-        ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
-        free(data);
-    }
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);

     struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
@ -438,13 +415,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     // concat class_embeddings and patch_embeddings
     struct ggml_tensor * embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
-    ggml_allocr_alloc(ctx->compute_alloc, embeddings);
-    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
-        void* zero_mem = malloc(ggml_nbytes(embeddings));
-        memset(zero_mem, 0, ggml_nbytes(embeddings));
-        ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
-        free(zero_mem);
-    }
+    ggml_set_name(embeddings, "embeddings");
+    ggml_set_input(embeddings);

     embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
             embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
@ -453,15 +425,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);

     struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions);
-    ggml_allocr_alloc(ctx->compute_alloc, positions);
-    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
-        int* positions_data = (int*)malloc(ggml_nbytes(positions));
-        for (int i = 0; i < num_positions; i++) {
-            positions_data[i] = i;
-        }
-        ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
-        free(positions_data);
-    }
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);

     embeddings =
         ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
@ -560,15 +525,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);

     struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
-    ggml_allocr_alloc(ctx->compute_alloc, patches);
-    if (!ggml_allocr_is_measure(ctx->compute_alloc)) {
-        int* patches_data = (int*)malloc(ggml_nbytes(patches));
-        for (int i = 0; i < num_patches; i++) {
-            patches_data[i] = i + 1;
-        }
-        ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
-        free(patches_data);
-    }
+    ggml_set_name(patches, "patches");
+    ggml_set_input(patches);

     // shape [1, 576, 1024]
     // ne is whcn, ne = [1024, 576, 1, 1]
@ -926,11 +884,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     // alloc memory and offload data
     new_clip->params_buffer = ggml_backend_alloc_buffer(new_clip->backend, buffer_size);
-    ggml_allocr* alloc = ggml_allocr_new_from_buffer(new_clip->params_buffer);
+    ggml_tallocr_t alloc = ggml_tallocr_new(new_clip->params_buffer);
     for (int i = 0; i < n_tensors; ++i) {
         const char * name = gguf_get_tensor_name(ctx, i);
         struct ggml_tensor * cur = ggml_get_tensor(new_clip->ctx_data, name);
-        ggml_allocr_alloc(alloc, cur);
+        ggml_tallocr_alloc(alloc, cur);
         const size_t offset = gguf_get_data_offset(ctx) + gguf_get_tensor_offset(ctx, i);
         fin.seekg(offset, std::ios::beg);
         if (!fin) {
@ -949,7 +907,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
             ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
         }
     }
-    ggml_allocr_free(alloc);
+    ggml_tallocr_free(alloc);
     fin.close();
 }
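Note: the load path above is the plain tensor-allocator pattern: weight tensors are created as metadata only, a ggml_tallocr places each one inside a pre-allocated backend buffer, and the bytes are then uploaded with ggml_backend_tensor_set. A hedged sketch of that pattern; the function name, the host_data source, and buffer_size are placeholders, not code from clip.cpp:

// sketch only: place n pre-created tensors into one backend buffer and upload their data
static void upload_tensors(ggml_backend_t backend, struct ggml_tensor ** tensors, void ** host_data, int n, size_t buffer_size) {
    ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(backend, buffer_size); // total size computed beforehand
    ggml_tallocr_t alloc = ggml_tallocr_new(buf);
    for (int i = 0; i < n; ++i) {
        ggml_tallocr_alloc(alloc, tensors[i]); // assigns an address inside buf
        ggml_backend_tensor_set(tensors[i], host_data[i], 0, ggml_nbytes(tensors[i]));
    }
    ggml_tallocr_free(alloc);
}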
@ -1077,15 +1035,12 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     // measure mem requirement and allocate
     {
         new_clip->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
-        new_clip->compute_alloc = ggml_allocr_new_measure_from_backend(new_clip->backend);
+        new_clip->compute_alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(new_clip->backend));
         clip_image_f32_batch batch;
         batch.size = 1;
         ggml_cgraph * gf = clip_image_build_graph(new_clip, &batch);
-        size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(new_clip->compute_alloc, gf);
-        ggml_allocr_free(new_clip->compute_alloc);
-        new_clip->compute_buffer = ggml_backend_alloc_buffer(new_clip->backend, compute_memory_buffer_size);
-        new_clip->compute_alloc = ggml_allocr_new_from_buffer(new_clip->compute_buffer);
+        ggml_gallocr_reserve(new_clip->compute_alloc, gf);
+        size_t compute_memory_buffer_size = ggml_gallocr_get_buffer_size(new_clip->compute_alloc, 0);
         printf("%s: compute allocated memory: %.2f MB\n", __func__, compute_memory_buffer_size /1024.0/1024.0);
     }
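Note: with ggml_gallocr the separate measure pass collapses into ggml_gallocr_reserve: the worst-case graph sizes the allocator's backing buffer once, later graphs are allocated into it without reallocating, and the size can be queried for logging. Roughly, with backend and the graph builder as placeholders:

// sketch only: reserve compute buffers from a worst-case graph, reuse them on every call
ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
struct ggml_cgraph * gf_worst = build_worst_case_graph(); // placeholder
ggml_gallocr_reserve(alloc, gf_worst);                    // grows the backing buffer to fit
size_t size = ggml_gallocr_get_buffer_size(alloc, 0);     // index 0: only one buffer type here
printf("compute buffer size: %.2f MB\n", size / 1024.0 / 1024.0);
// per invocation: ggml_gallocr_alloc_graph(alloc, gf), then set inputs and compute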
@ -1267,12 +1222,72 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         GGML_ASSERT(batch_size == 1); // TODO: support multiple images
     }

-    // reset alloc buffer to clean the memory from previous invocations
-    ggml_allocr_reset(ctx->compute_alloc);

     // build the inference graph
     ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
-    ggml_allocr_alloc_graph(ctx->compute_alloc, gf);
+    ggml_gallocr_alloc_graph(ctx->compute_alloc, gf);
+
+    // set inputs
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+    const int image_size = hparams.image_size;
+    const int patch_size = hparams.patch_size;
+    const int num_patches = ((image_size / patch_size) * (image_size / patch_size));
+    const int num_positions = num_patches + 1;
+
+    {
+        struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
+        float * data = (float *)malloc(ggml_nbytes(inp_raw));
+
+        for (size_t i = 0; i < imgs->size; i++) {
+            const int nx = imgs->data[i].nx;
+            const int ny = imgs->data[i].ny;
+            GGML_ASSERT(nx == image_size && ny == image_size);
+
+            const int n = nx * ny;
+
+            for (int b = 0; b < batch_size; b++) {
+                for (int k = 0; k < 3; k++) {
+                    for (int y = 0; y < ny; y++) {
+                        for (int x = 0; x < nx; x++) {
+                            data[(b * 3 * n) + k * n + y * nx + x] = imgs->data[b].buf[3 * (y * nx + x) + k];
+                        }
+                    }
+                }
+            }
+        }
+        ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
+        free(data);
+    }
+
+    {
+        struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
+        void* zero_mem = malloc(ggml_nbytes(embeddings));
+        memset(zero_mem, 0, ggml_nbytes(embeddings));
+        ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
+        free(zero_mem);
+    }
+
+    {
+        struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+        int* positions_data = (int*)malloc(ggml_nbytes(positions));
+        for (int i = 0; i < num_positions; i++) {
+            positions_data[i] = i;
+        }
+        ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
+        free(positions_data);
+    }
+
+    {
+        struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
+        int* patches_data = (int*)malloc(ggml_nbytes(patches));
+        for (int i = 0; i < num_patches; i++) {
+            patches_data[i] = i + 1;
+        }
+        ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
+        free(patches_data);
+    }

     if (ggml_backend_is_cpu(ctx->backend)) {
         ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
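Note: the clip.cpp changes split graph construction from input upload: while the graph is built, input tensors are only named and flagged with ggml_set_input; their data is written after allocation by looking them up by name. A minimal sketch of that round trip, with an illustrative tensor name and size:

// during graph construction: declare the input, do not touch its data
struct ggml_tensor * inp = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n);
ggml_set_name(inp, "inp");
ggml_set_input(inp);

// after ggml_gallocr_alloc_graph(alloc, gf): locate the tensor and upload the data
struct ggml_tensor * t = ggml_graph_get_tensor(gf, "inp");
std::vector<float> values(n, 0.0f); // placeholder input values
ggml_backend_tensor_set(t, values.data(), 0, ggml_nbytes(t));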

ggml-backend.c

@ -1411,6 +1411,8 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
             return false;
         }
     }
+
+    return true;
 }

 static bool ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {

llama.cpp

@ -1839,8 +1839,6 @@ struct llama_context {
     // memory buffers used to evaluate the model
     std::vector<uint8_t> buf_compute_meta;
     ggml_backend_sched_t sched = nullptr;
-    // allocator for the input tensors
-    ggml_tallocr * alloc = nullptr;

     // input tensors
     ggml_backend_buffer_t buf_input = nullptr;
@ -6996,12 +6994,10 @@ struct llm_build_context {
 static struct ggml_cgraph * llama_build_graph(
         llama_context & lctx,
-    const llama_batch & batch) {
+    const llama_batch & batch,
+                  bool   worst_case) {
     const auto & model = lctx.model;

-    // check if we should build the worst-case graph (for memory measurement)
-    const bool worst_case = ggml_tallocr_is_measure(lctx.alloc);

     // this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
     llm_build_cb cb = [&](struct ggml_tensor * cur, const char * name, int il) {
         if (il >= 0) {
@ -7022,67 +7018,6 @@ static struct ggml_cgraph * llama_build_graph(
     struct llm_build_context llm(lctx, batch, cb, worst_case);

-    //
-    // set input data
-    //
-
-    if (!ggml_tallocr_is_measure(lctx.alloc)) {
-        if (batch.token) {
-            const int64_t n_tokens = batch.n_tokens;
-
-            ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
-        }
-
-        if (batch.embd) {
-            const int64_t n_embd = llm.n_embd;
-            const int64_t n_tokens = batch.n_tokens;
-
-            ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
-        }
-
-        if (batch.pos) {
-            const int64_t n_tokens = batch.n_tokens;
-
-            ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
-        }
-
-        {
-            const int64_t n_kv = llm.n_kv;
-            const int64_t n_tokens = batch.n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
-            float * data = (float *) lctx.inp_KQ_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int j = 0; j < n_tokens; ++j) {
-                    const llama_pos pos = batch.pos[j];
-                    const llama_seq_id seq_id = batch.seq_id[j][0];
-
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                            f = -INFINITY;
-                        } else {
-                            f = 0;
-                        }
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
-                    }
-                }
-            }
-        }
-
-        if (llm.do_rope_shift) {
-            const int64_t n_ctx = llm.n_ctx;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
-            int32_t * data = (int32_t *) lctx.inp_K_shift->data;
-
-            for (int i = 0; i < n_ctx; ++i) {
-                data[i] = lctx.kv_self.cells[i].delta;
-            }
-        }
-    }
-
     llm.init();

     switch (model.arch) {
@ -7167,6 +7102,73 @@ static struct ggml_cgraph * llama_build_graph(
     return result;
 }

+static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
+    //
+    // set input data
+    //
+
+    const auto & hparams = lctx.model.hparams;
+    const auto & cparams = lctx.cparams;
+    const auto & kv_self = lctx.kv_self;
+
+    if (batch.token) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
+    }
+
+    if (batch.embd) {
+        const int64_t n_embd = hparams.n_embd;
+        const int64_t n_tokens = batch.n_tokens;
+
+        ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
+    }
+
+    if (batch.pos) {
+        const int64_t n_tokens = batch.n_tokens;
+
+        ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+    }
+
+    {
+        const int64_t n_kv = kv_self.n;
+        const int64_t n_tokens = batch.n_tokens;
+
+        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+        float * data = (float *) lctx.inp_KQ_mask->data;
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_pos pos = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j][0];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    float f;
+                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                        f = -INFINITY;
+                    } else {
+                        f = 0;
+                    }
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                }
+            }
+        }
+    }
+
+    if (kv_self.has_shift) {
+        const int64_t n_ctx = cparams.n_ctx;
+
+        assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
+        int32_t * data = (int32_t *) lctx.inp_K_shift->data;
+
+        for (int i = 0; i < n_ctx; ++i) {
+            data[i] = lctx.kv_self.cells[i].delta;
+        }
+    }
+}
+
 // decode a batch of tokens by evaluating the transformer
 //
 // - lctx: llama context
@ -7265,7 +7267,7 @@ static int llama_decode_internal(
     ggml_backend_sched_reset(lctx.sched);
     ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);

-    ggml_cgraph * gf = llama_build_graph(lctx, batch);
+    ggml_cgraph * gf = llama_build_graph(lctx, batch, false);

     // the output is always the last tensor in the graph
     struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
@ -7305,6 +7307,9 @@ static int llama_decode_internal(
     if (lctx.backend_cpu != nullptr) {
         ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
     }
+
+    llama_set_inputs(lctx, batch);
+
     ggml_backend_sched_graph_compute(lctx.sched, gf);

     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
@ -10841,23 +10846,27 @@ struct llama_context * llama_new_context_with_model(
         ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead());

         ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES);
-        ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);

         // build worst-case graph
         int n_tokens = (int)std::min(cparams.n_ctx, cparams.n_batch);
         int n_past = cparams.n_ctx - n_tokens;
         llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-        ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0));
+        ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, n_past, 0), true);

         // initialize scheduler with the worst-case graph
-        ggml_backend_sched_init_measure(ctx->sched, gf);
-        ctx->alloc = ggml_backend_sched_get_tallocr(ctx->sched, ctx->backend_cpu);
+        if (!ggml_backend_sched_reserve(ctx->sched, gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            llama_free(ctx);
+            return nullptr;
+        }

-        for (ggml_backend_t backend : ctx->backends) {
-            ggml_backend_buffer_t buf = ggml_backend_sched_get_buffer(ctx->sched, backend);
+        for (size_t i = 0; i < ctx->backends.size(); i++) {
+            ggml_backend_t backend = ctx->backends[i];
+            ggml_backend_buffer_type_t buft = backend_buft[i];
+            size_t size = ggml_backend_sched_get_buffer_size(ctx->sched, backend);
             LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
-                    ggml_backend_buffer_name(buf),
-                    ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
+                    ggml_backend_buft_name(buft),
+                    size / 1024.0 / 1024.0);
         }

         // note: the number of splits during measure is higher than during inference due to the kv shift
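Note: llama.cpp now mirrors the same pattern at the scheduler level: the worst-case graph only reserves compute buffers through ggml_backend_sched_reserve, per-backend sizes are queried for logging, and input data is written by llama_set_inputs right before each real compute. A rough sketch with placeholder variables (sched, backends, backend_buft, gf_worst_case):

// one-time: size the scheduler's compute buffers from a worst-case graph
if (!ggml_backend_sched_reserve(sched, gf_worst_case)) {
    fprintf(stderr, "failed to allocate compute buffers\n"); // placeholder error handling
}
for (size_t i = 0; i < backends.size(); i++) {
    size_t size = ggml_backend_sched_get_buffer_size(sched, backends[i]);
    fprintf(stderr, "%s compute buffer size = %.2f MiB\n", ggml_backend_buft_name(backend_buft[i]), size / 1024.0 / 1024.0);
}

// per batch: reset, build, set inputs, compute
ggml_backend_sched_reset(sched);
ggml_cgraph * gf = llama_build_graph(lctx, batch, /*worst_case =*/ false);
llama_set_inputs(lctx, batch);
ggml_backend_sched_graph_compute(sched, gf);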