ggml : add graph tensor allocator (#2411)
* ggml : add graph tensor allocator * ggml : don't calculate data pointer of unallocated tensors when creating a view with an offset * ggml : refactor ggml_view_Nd into ggml_view_tensor_offset
This commit is contained in:
parent
11f3ca06b8
commit
a113689571
7 changed files with 813 additions and 89 deletions
242
llama.cpp
242
llama.cpp
|
@ -56,8 +56,14 @@
|
|||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
#if !defined(GGML_USE_CUBLAS) && !defined(GGML_USE_METAL)
|
||||
#include "ggml-alloc.h"
|
||||
#define LLAMA_USE_ALLOCATOR
|
||||
#else
|
||||
#define LLAMA_USE_SCRATCH
|
||||
#define LLAMA_MAX_SCRATCH_BUFFERS 16
|
||||
#endif
|
||||
|
||||
|
||||
// available llama models
|
||||
enum e_model {
|
||||
|
@ -327,13 +333,22 @@ struct llama_model {
|
|||
|
||||
struct llama_context {
|
||||
llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}
|
||||
#ifdef GGML_USE_METAL
|
||||
~llama_context() {
|
||||
if (model_owner) {
|
||||
delete &model;
|
||||
}
|
||||
#ifdef GGML_USE_METAL
|
||||
if (ctx_metal) {
|
||||
ggml_metal_free(ctx_metal);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef LLAMA_USE_ALLOCATOR
|
||||
if (alloc) {
|
||||
ggml_allocr_free(alloc);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::mt19937 rng;
|
||||
|
||||
bool has_evaluated_once = false;
|
||||
|
@ -371,7 +386,17 @@ struct llama_context {
|
|||
// memory buffers used to evaluate the model
|
||||
// TODO: move in llama_state
|
||||
llama_ctx_buffer buf_compute;
|
||||
|
||||
#ifdef LLAMA_USE_ALLOCATOR
|
||||
llama_ctx_buffer buf_alloc;
|
||||
ggml_allocr * alloc = NULL;
|
||||
#endif
|
||||
|
||||
#ifdef LLAMA_USE_SCRATCH
|
||||
llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
|
||||
int buf_last = 0;
|
||||
size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
ggml_metal_context * ctx_metal = NULL;
|
||||
|
@ -381,9 +406,6 @@ struct llama_context {
|
|||
ggml_mpi_context * ctx_mpi = NULL;
|
||||
#endif
|
||||
|
||||
int buf_last = 0;
|
||||
size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
|
||||
|
||||
void use_buf(struct ggml_context * ctx, int i) {
|
||||
#if defined(LLAMA_USE_SCRATCH)
|
||||
size_t last_size = 0;
|
||||
|
@ -1230,12 +1252,16 @@ static void llama_model_load_internal(
|
|||
const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
|
||||
|
||||
// this is the total memory required to run the inference
|
||||
const size_t mem_required =
|
||||
size_t mem_required =
|
||||
ctx_size +
|
||||
mmapped_size - vram_weights + // weights in VRAM not in memory
|
||||
mmapped_size - vram_weights; // weights in VRAM not in memory
|
||||
|
||||
#ifndef LLAMA_USE_ALLOCATOR
|
||||
mem_required +=
|
||||
MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
|
||||
MEM_REQ_SCRATCH1().at(model.type) +
|
||||
MEM_REQ_EVAL().at(model.type);
|
||||
#endif
|
||||
|
||||
// this is the memory required by one llama_state
|
||||
const size_t mem_required_state =
|
||||
|
@ -1360,32 +1386,15 @@ static bool llama_model_load(
|
|||
}
|
||||
}
|
||||
|
||||
// evaluate the transformer
|
||||
//
|
||||
// - lctx: llama context
|
||||
// - tokens: new batch of tokens to process
|
||||
// - embd embeddings input
|
||||
// - n_tokens number of tokens
|
||||
// - n_past: the context size so far
|
||||
// - n_threads: number of threads to use
|
||||
//
|
||||
static bool llama_eval_internal(
|
||||
static struct ggml_cgraph * llama_build_graph(
|
||||
llama_context & lctx,
|
||||
const llama_token * tokens,
|
||||
const float * embd,
|
||||
int n_tokens,
|
||||
int n_past,
|
||||
int n_threads,
|
||||
const char * cgraph_fname) {
|
||||
int n_past) {
|
||||
|
||||
LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
|
||||
|
||||
#ifdef GGML_USE_MPI
|
||||
ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
|
||||
#endif
|
||||
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
const int N = n_tokens;
|
||||
|
||||
const auto & model = lctx.model;
|
||||
|
@ -1401,10 +1410,8 @@ static bool llama_eval_internal(
|
|||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_head_kv = hparams.n_head_kv;
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
const int64_t n_vocab = hparams.n_vocab;
|
||||
const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
|
||||
|
||||
LLAMA_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
const float freq_base = hparams.rope_freq_base;
|
||||
|
@ -1416,26 +1423,35 @@ static bool llama_eval_internal(
|
|||
auto & mem_per_token = lctx.mem_per_token;
|
||||
auto & buf_compute = lctx.buf_compute;
|
||||
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute.size,
|
||||
/*.mem_buffer =*/ buf_compute.addr,
|
||||
/*.no_alloc =*/ false,
|
||||
};
|
||||
|
||||
#ifdef LLAMA_USE_ALLOCATOR
|
||||
params.no_alloc = true;
|
||||
#endif
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
// for big prompts, if BLAS is enabled, it is better to use only one thread
|
||||
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
|
||||
n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
if (tokens) {
|
||||
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
|
||||
#ifdef LLAMA_USE_ALLOCATOR
|
||||
ggml_allocr_alloc(lctx.alloc, inp_tokens);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
|
||||
}
|
||||
#else
|
||||
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
|
||||
#endif
|
||||
ggml_set_name(inp_tokens, "inp_tokens");
|
||||
|
||||
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
|
||||
|
@ -1445,7 +1461,15 @@ static bool llama_eval_internal(
|
|||
#endif
|
||||
|
||||
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
|
||||
|
||||
#ifdef LLAMA_USE_ALLOCATOR
|
||||
ggml_allocr_alloc(lctx.alloc, inpL);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
|
||||
}
|
||||
#else
|
||||
memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
|
||||
#endif
|
||||
}
|
||||
|
||||
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||
|
@ -1472,6 +1496,17 @@ static bool llama_eval_internal(
|
|||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||
#ifdef LLAMA_USE_ALLOCATOR
|
||||
ggml_allocr_alloc(lctx.alloc, KQ_scale);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
|
||||
}
|
||||
#else
|
||||
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
|
||||
#endif
|
||||
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_format_name(inpL, "layer_inp_%d", il);
|
||||
|
||||
|
@ -1567,9 +1602,6 @@ static bool llama_eval_internal(
|
|||
ggml_set_name(KQ, "KQ");
|
||||
|
||||
// KQ_scaled = KQ / sqrt(n_embd_head)
|
||||
struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
|
||||
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
|
||||
|
||||
// KQ_scaled shape [n_past + N, N, n_head, 1]
|
||||
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
|
||||
offload_func_kq(KQ_scaled);
|
||||
|
@ -1685,9 +1717,6 @@ static bool llama_eval_internal(
|
|||
|
||||
lctx.use_buf(ctx0, 0);
|
||||
|
||||
// used at the end to optionally extract the embeddings
|
||||
struct ggml_tensor * embeddings = NULL;
|
||||
|
||||
// norm
|
||||
{
|
||||
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
|
@ -1698,8 +1727,6 @@ static bool llama_eval_internal(
|
|||
cur = ggml_mul(ctx0, cur, model.norm);
|
||||
// offload_func_nr(cur); // TODO CPU + GPU mirrored backend
|
||||
ggml_set_name(cur, "result_norm");
|
||||
|
||||
embeddings = cur;
|
||||
}
|
||||
|
||||
// lm_head
|
||||
|
@ -1711,12 +1738,82 @@ static bool llama_eval_internal(
|
|||
// logits -> probs
|
||||
//cur = ggml_soft_max_inplace(ctx0, cur);
|
||||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
// fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs);
|
||||
if (mem_per_token == 0) {
|
||||
mem_per_token = ggml_used_mem(ctx0)/N;
|
||||
}
|
||||
|
||||
#if 0
|
||||
printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
|
||||
ggml_used_mem(ctx0)/1024.0/1024.0,
|
||||
lctx.get_buf_max_mem(0)/1024.0/1024.0,
|
||||
lctx.get_buf_max_mem(1)/1024.0/1024.0,
|
||||
lctx.work_buffer.size()/1024.0/1024.0,
|
||||
n_past, N);
|
||||
#endif
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
// evaluate the transformer
|
||||
//
|
||||
// - lctx: llama context
|
||||
// - tokens: new batch of tokens to process
|
||||
// - embd embeddings input
|
||||
// - n_tokens number of tokens
|
||||
// - n_past: the context size so far
|
||||
// - n_threads: number of threads to use
|
||||
//
|
||||
static bool llama_eval_internal(
|
||||
llama_context & lctx,
|
||||
const llama_token * tokens,
|
||||
const float * embd,
|
||||
int n_tokens,
|
||||
int n_past,
|
||||
int n_threads,
|
||||
const char * cgraph_fname) {
|
||||
|
||||
LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
|
||||
|
||||
const int64_t t_start_us = ggml_time_us();
|
||||
|
||||
#ifdef GGML_USE_MPI
|
||||
ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
|
||||
#endif
|
||||
|
||||
const int N = n_tokens;
|
||||
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const auto & kv_self = lctx.kv_self;
|
||||
|
||||
LLAMA_ASSERT(!!kv_self.ctx);
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_vocab = hparams.n_vocab;
|
||||
|
||||
#ifdef LLAMA_USE_ALLOCATOR
|
||||
ggml_allocr_reset(lctx.alloc);
|
||||
#endif
|
||||
|
||||
ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
|
||||
|
||||
#ifdef LLAMA_USE_ALLOCATOR
|
||||
ggml_allocr_alloc_graph(lctx.alloc, gf);
|
||||
#endif
|
||||
|
||||
// fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
|
||||
// for big prompts, if BLAS is enabled, it is better to use only one thread
|
||||
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
|
||||
n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
|
||||
|
||||
#if GGML_USE_MPI
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
|
||||
#endif
|
||||
|
||||
|
@ -1760,6 +1857,10 @@ static bool llama_eval_internal(
|
|||
lctx.kv_self.n = n_past + N;
|
||||
|
||||
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
|
||||
struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];
|
||||
|
||||
LLAMA_ASSERT(strcmp(res->name, "result_output") == 0);
|
||||
LLAMA_ASSERT(strcmp(embeddings->name, "result_norm") == 0);
|
||||
|
||||
if (cgraph_fname) {
|
||||
ggml_graph_export(gf, cgraph_fname);
|
||||
|
@ -1798,21 +1899,6 @@ static bool llama_eval_internal(
|
|||
memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
|
||||
}
|
||||
|
||||
if (mem_per_token == 0) {
|
||||
mem_per_token = ggml_used_mem(ctx0)/N;
|
||||
}
|
||||
|
||||
#if 0
|
||||
printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
|
||||
ggml_used_mem(ctx0)/1024.0/1024.0,
|
||||
lctx.get_buf_max_mem(0)/1024.0/1024.0,
|
||||
lctx.get_buf_max_mem(1)/1024.0/1024.0,
|
||||
lctx.work_buffer.size()/1024.0/1024.0,
|
||||
n_past, N);
|
||||
#endif
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
// measure the performance only for the single-token evals
|
||||
if (N == 1) {
|
||||
lctx.t_eval_us += ggml_time_us() - t_start_us;
|
||||
|
@ -3180,10 +3266,47 @@ struct llama_context * llama_new_context_with_model(
|
|||
ctx->embedding.resize(hparams.n_embd);
|
||||
}
|
||||
|
||||
ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
|
||||
#ifdef LLAMA_USE_ALLOCATOR
|
||||
{
|
||||
static const size_t tensor_alignment = 32;
|
||||
// the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data
|
||||
ctx->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
// create measure allocator
|
||||
ctx->alloc = ggml_allocr_new_measure(tensor_alignment);
|
||||
|
||||
// build worst-case graph
|
||||
int n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
|
||||
int n_past = hparams.n_ctx - n_tokens;
|
||||
llama_token token = llama_token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past);
|
||||
|
||||
// measure memory requirements for the graph
|
||||
size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment;
|
||||
|
||||
fprintf(stderr, "%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);
|
||||
|
||||
// debug - for comparison with scratch buffer
|
||||
//size_t prev_req =
|
||||
// MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) +
|
||||
// MEM_REQ_SCRATCH1().at(ctx->model.type) +
|
||||
// MEM_REQ_EVAL().at(ctx->model.type);
|
||||
//fprintf(stderr, "%s: (debug) equivalent with scratch buffer = %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0);
|
||||
|
||||
// recreate allocator with exact memory requirements
|
||||
ggml_allocr_free(ctx->alloc);
|
||||
|
||||
ctx->buf_alloc.resize(alloc_size);
|
||||
ctx->alloc = ggml_allocr_new(ctx->buf_alloc.addr, ctx->buf_alloc.size, tensor_alignment);
|
||||
}
|
||||
#else
|
||||
ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
|
||||
#endif
|
||||
|
||||
#ifdef LLAMA_USE_SCRATCH
|
||||
ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
|
||||
ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
|
@ -3253,9 +3376,6 @@ struct llama_context * llama_init_from_file(
|
|||
}
|
||||
|
||||
void llama_free(struct llama_context * ctx) {
|
||||
if (ctx->model_owner) {
|
||||
delete &ctx->model;
|
||||
}
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue