llama.cpp : allocate graph in the context
ggml-ci
parent 567b5e24ed
commit 77d662faa5
2 changed files with 15 additions and 15 deletions
ggml.c (2 changes)

@@ -4212,7 +4212,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
 }
 
 size_t ggml_tensor_overhead(void) {
-    return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE; // REVIEW: i don't think we need to 16 here because GGML_OBJECT_SIZE and GGML_TENSOR_SIZE are already aligned
+    return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE; // REVIEW: i don't think we need to add 16 here because GGML_OBJECT_SIZE and GGML_TENSOR_SIZE are already aligned
 }
 
 bool ggml_is_transposed(const struct ggml_tensor * tensor) {
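Note: ggml_tensor_overhead() reported above is the per-tensor bookkeeping cost (object header plus tensor struct) that a ggml context pays on top of the tensor data itself. As a minimal sketch of how a caller might use it to size a context, assuming the ggml API of this tree (the tensor count and no_alloc setup below are hypothetical, not taken from the diff):

    #include "ggml.h"

    // rough upper bound for a context that will only hold tensor metadata
    // (no_alloc = true), e.g. tensors whose data lives in an external buffer
    static size_t example_ctx_size(int n_tensors) {
        return (size_t) n_tensors * ggml_tensor_overhead();
    }

    // usage sketch:
    //   struct ggml_init_params params = {
    //       /*.mem_size   =*/ example_ctx_size(1024),
    //       /*.mem_buffer =*/ NULL,
    //       /*.no_alloc   =*/ true,
    //   };
    //   struct ggml_context * ctx = ggml_init(params);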
llama.cpp (28 changes)

@@ -1424,7 +1424,7 @@ static bool llama_eval_internal(
 
     struct ggml_context * ctx0 = ggml_init(params);
 
-    ggml_cgraph gf = {};
+    ggml_cgraph * gf = ggml_new_graph(ctx0);
 
     // for big prompts, if BLAS is enabled, it is better to use only one thread
     // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
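Note: this is the heart of the commit. The forward graph is no longer a stack object; ggml_new_graph(ctx0) carves the ggml_cgraph out of the same context buffer that holds the tensors, so every later use switches from &gf to gf and the graph's storage disappears together with ctx0. A minimal self-contained sketch of the pattern, assuming the ggml API as it exists in this tree (sizes and tensors below are illustrative only):

    #include "ggml.h"

    static void build_graph_in_ctx(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024 + ggml_graph_overhead(), // leave room for the cgraph itself
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx0 = ggml_init(params);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 8);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 8);
        struct ggml_tensor * c = ggml_add(ctx0, a, b);

        // before this commit: struct ggml_cgraph gf = {}; ggml_build_forward_expand(&gf, c);
        struct ggml_cgraph * gf = ggml_new_graph(ctx0); // graph allocated inside the context
        ggml_build_forward_expand(gf, c);

        ggml_free(ctx0); // releases the tensors and the graph in one go
    }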
@@ -1541,8 +1541,8 @@ static bool llama_eval_internal(
             ggml_set_name(v, "v");
 
             // important: storing RoPE-ed version of K in the KV cache!
-            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
-            ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
         }
 
         struct ggml_tensor * Q =
@@ -1712,7 +1712,7 @@ static bool llama_eval_internal(
     //cur = ggml_soft_max_inplace(ctx0, cur);
 
     // run the computation
-    ggml_build_forward_expand(&gf, cur);
+    ggml_build_forward_expand(gf, cur);
 
     // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs);
 
@@ -1723,10 +1723,10 @@ static bool llama_eval_internal(
 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
         if (!ggml_metal_if_optimized(lctx.ctx_metal)) {
-            ggml_metal_graph_find_concurrency(lctx.ctx_metal,&gf);
+            ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf);
         }
         ggml_metal_set_n_cb     (lctx.ctx_metal, n_threads);
-        ggml_metal_graph_compute(lctx.ctx_metal, &gf);
+        ggml_metal_graph_compute(lctx.ctx_metal, gf);
         ggml_metal_get_tensor   (lctx.ctx_metal, cur);
     } else {
         // IMPORTANT:
@@ -1745,34 +1745,34 @@ static bool llama_eval_internal(
             ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v);
         }
 
-        ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
+        ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
     }
 #else
-    ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads);
+    ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
 #endif
 
 #if GGML_USE_MPI
-    ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer);
+    ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer);
 #endif
 
     // update kv token count
     lctx.kv_self.n = n_past + N;
 
-    struct ggml_tensor * res = gf.nodes[gf.n_nodes - 1];
+    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
 
     if (cgraph_fname) {
-        ggml_graph_export(&gf, cgraph_fname);
+        ggml_graph_export(gf, cgraph_fname);
     }
 
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
     // requires GGML_PERF to be defined
-    ggml_graph_print(&gf);
+    ggml_graph_print(gf);
 #endif
 
     // plot the computation graph in dot format (for debugging purposes)
     //if (n_past%100 == 0) {
-    //    ggml_graph_dump_dot(&gf, NULL, "llama.dot");
+    //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
     //}
 
     // extract logits
@@ -3177,7 +3177,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->embedding.resize(hparams.n_embd);
         }
 
-        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type));
+        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
 
         ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
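Note: because the cgraph now lives inside ctx0, the buffer backing that context (buf_compute) has to grow by ggml_graph_overhead(), which is what the last hunk does. The arithmetic, as a sketch (the eval figure below is a made-up placeholder, not a value from llama.cpp's MEM_REQ_EVAL table):

    #include "ggml.h"

    // placeholder for the model-type-dependent eval requirement from MEM_REQ_EVAL()
    static const size_t mem_req_eval_example = 768u*1024*1024;

    // size for the buffer passed to ggml_init() in llama_eval_internal():
    // tensor data and per-tensor overhead must fit in the eval estimate, and
    // the ggml_cgraph object itself now needs ggml_graph_overhead() extra bytes
    static size_t compute_buf_size(void) {
        return mem_req_eval_example + ggml_graph_overhead();
    }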