diff --git a/src/llama.cpp b/src/llama.cpp index 80a0dd0f4..2f000cb5a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -113,7 +113,6 @@ #endif // bump if necessary -#define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_LAYERS 512 #define LLAMA_MAX_EXPERTS 160 // DeepSeekV2 @@ -3657,6 +3656,15 @@ namespace GGUFMeta { using llama_buf_map = std::unordered_map; +// TODO: update when needed or think of some clever automatic way to do this +static size_t llama_model_max_nodes(const llama_model & model) { + if (model.arch == LLM_ARCH_LLAMA && model.hparams.n_layer > 400) { // llama-3 405B + return 32768; + } + + return 8192; +} + struct llama_model_loader { int n_kv = 0; int n_tensors = 0; @@ -8495,7 +8503,7 @@ struct llm_build_context { } struct ggml_cgraph * build_k_shift() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); GGML_ASSERT(kv_self.size == n_ctx); @@ -8526,7 +8534,7 @@ struct llm_build_context { } struct ggml_cgraph * build_s_copy() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); GGML_ASSERT(kv_self.recurrent); @@ -8549,7 +8557,7 @@ struct llm_build_context { } struct ggml_cgraph * build_defrag(const std::vector & ids) { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); for (uint32_t i = 0; i < ids.size(); ++i) { const uint32_t id = ids[i]; @@ -8790,7 +8798,7 @@ struct llm_build_context { } struct ggml_cgraph * build_llama() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -8933,7 +8941,7 @@ struct llm_build_context { } struct ggml_cgraph * build_baichuan() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -9048,7 +9056,7 @@ struct llm_build_context { } struct ggml_cgraph * build_xverse() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -9151,7 +9159,7 @@ struct llm_build_context { } struct ggml_cgraph * build_falcon() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -9271,7 +9279,7 @@ struct llm_build_context { } struct ggml_cgraph * build_grok() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -9428,7 +9436,7 @@ struct llm_build_context { } struct ggml_cgraph * build_dbrx() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -9554,7 +9562,7 @@ struct llm_build_context { } struct ggml_cgraph * build_starcoder() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -9658,7 +9666,7 @@ struct llm_build_context { } struct ggml_cgraph * build_refact() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -9752,7 +9760,7 @@ struct llm_build_context { } struct ggml_cgraph * build_bert() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -9946,7 +9954,7 @@ struct llm_build_context { } struct ggml_cgraph * build_bloom() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -10047,7 +10055,7 @@ struct llm_build_context { } struct ggml_cgraph * build_mpt() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -10337,7 +10345,7 @@ struct llm_build_context { } struct ggml_cgraph * build_qwen() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -10449,7 +10457,7 @@ struct llm_build_context { } struct ggml_cgraph * build_qwen2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -10561,7 +10569,7 @@ struct llm_build_context { } struct ggml_cgraph * build_qwen2moe() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -10707,7 +10715,7 @@ struct llm_build_context { } struct ggml_cgraph * build_phi2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -10828,7 +10836,7 @@ struct llm_build_context { } struct ggml_cgraph * build_phi3() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -11060,7 +11068,7 @@ struct llm_build_context { } struct ggml_cgraph * build_gpt2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -11165,7 +11173,7 @@ struct llm_build_context { } struct ggml_cgraph * build_codeshell() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -11276,7 +11284,7 @@ struct llm_build_context { } struct ggml_cgraph * build_orion() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11394,7 +11402,7 @@ struct llm_build_context { } struct ggml_cgraph * build_internlm2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11515,7 +11523,7 @@ struct llm_build_context { // https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738 // based on the original build_llama() function struct ggml_cgraph * build_minicpm() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11659,7 +11667,7 @@ struct llm_build_context { } struct ggml_cgraph * build_gemma() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -11767,7 +11775,7 @@ struct llm_build_context { } struct ggml_cgraph * build_gemma2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head_k = hparams.n_embd_head_k; @@ -11902,7 +11910,7 @@ struct llm_build_context { struct ggml_cgraph * build_starcoder2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -12021,7 +12029,7 @@ struct llm_build_context { } struct ggml_cgraph * build_mamba() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t d_model = n_embd; const int64_t d_conv = hparams.ssm_d_conv; @@ -12170,7 +12178,7 @@ struct llm_build_context { struct ggml_cgraph * build_command_r() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -12324,7 +12332,7 @@ struct llm_build_context { // * removed bias // * removed MoE struct ggml_cgraph * build_olmo() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -12448,7 +12456,7 @@ struct llm_build_context { } struct ggml_cgraph * build_openelm() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -12573,7 +12581,7 @@ struct llm_build_context { } struct ggml_cgraph * build_gptneox() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -12715,7 +12723,7 @@ struct llm_build_context { } struct ggml_cgraph * build_arctic() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -12847,7 +12855,7 @@ struct llm_build_context { } struct ggml_cgraph * build_deepseek2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -13075,7 +13083,7 @@ struct llm_build_context { } struct ggml_cgraph * build_bitnet() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -13215,7 +13223,7 @@ struct llm_build_context { } struct ggml_cgraph * build_t5() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens int32_t n_tokens = this->n_tokens; @@ -13532,7 +13540,7 @@ struct llm_build_context { } struct ggml_cgraph * build_jais() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -13624,7 +13632,7 @@ struct llm_build_context { } struct ggml_cgraph * build_chatglm() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -14951,9 +14959,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { // each move requires 6*n_layer tensors (see build_defrag) // - source view, destination view, copy operation // - x2 for keys and values - //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer); + //const uint32_t max_moves = llama_model_max_nodes(model)/(6*n_layer); // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 - const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer); + const uint32_t max_moves = (llama_model_max_nodes(lctx.model) - 2*n_layer)/(6*n_layer); // determine which KV cells to move where // @@ -19351,8 +19359,10 @@ struct llama_context * llama_new_context_with_model( } } + const size_t max_nodes = llama_model_max_nodes(*model); + // buffer used to store the computation graph and the tensor meta data - ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false)); + ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false)); // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary bool pipeline_parallel = @@ -19365,7 +19375,7 @@ struct llama_context * llama_new_context_with_model( // currently this is only implemented in the CUDA backend pipeline_parallel = false; #endif - ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES, pipeline_parallel); + ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), max_nodes, pipeline_parallel); if (pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched));