diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 0f4b62c43..182e9e6dd 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -3574,7 +3574,7 @@ size_t llama_model::size() const {
 }
 
 size_t llama_model::max_nodes() const {
-    return std::max<size_t>(8192, tensors_by_name.size()*5);
+    return std::max<size_t>(65536, tensors_by_name.size()*5);
 }
 
 size_t llama_model::n_devices() const {
diff --git a/src/llama.cpp b/src/llama.cpp
index 607f27861..e4d893462 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1092,7 +1092,8 @@ struct llm_build_context {
         llama_context  & lctx,
     const llama_ubatch & ubatch,
     const llm_build_cb & cb,
-                  bool   worst_case) :
+                  bool   worst_case,
+                  bool   warmup) :
         model            (lctx.model),
         lctx             (lctx),
         hparams          (model.hparams),
@@ -1110,7 +1111,7 @@
         n_embd_head_v    (hparams.n_embd_head_v),
         n_embd_v_gqa     (hparams.n_embd_v_gqa()),
         n_expert         (hparams.n_expert),
-        n_expert_used    (hparams.n_expert_used),
+        n_expert_used    (warmup ? hparams.n_expert : hparams.n_expert_used),
         freq_base        (cparams.rope_freq_base),
         freq_scale       (cparams.rope_freq_scale),
         ext_factor       (cparams.yarn_ext_factor),
@@ -8118,7 +8119,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
 
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
 
-    struct llm_build_context llm(lctx, dummy, cb, false);
+    struct llm_build_context llm(lctx, dummy, cb, false, false);
 
     llm.init();
 
@@ -8135,7 +8136,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
 
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
 
-    struct llm_build_context llm(lctx, dummy, cb, false);
+    struct llm_build_context llm(lctx, dummy, cb, false, false);
 
     llm.init();
 
@@ -8186,7 +8187,12 @@ static struct ggml_cgraph * llama_build_graph(
 
     struct ggml_cgraph * result = NULL;
 
-    struct llm_build_context llm(lctx, ubatch, cb, worst_case);
+    // Warmup batches are built as [BOS, EOS]; when detected, route through all
+    // experts (see n_expert_used above) so every expert tensor gets touched once.
+    // NOTE: ubatch.token is null for embedding-input batches, so guard before deref.
+    const llama_vocab * vocab = llama_model_get_vocab(&model);
+    llama_token bos = llama_vocab_bos(vocab);
+    llama_token eos = llama_vocab_eos(vocab);
+    bool is_warming_up = (ubatch.token != nullptr && ubatch.n_tokens == 2 && ubatch.token[0] == bos && ubatch.token[1] == eos);
+    struct llm_build_context llm(lctx, ubatch, cb, worst_case, is_warming_up);
 
     llm.init();
 