diff --git a/main.cpp b/main.cpp index 4b0229a64..ea0a469ca 100644 --- a/main.cpp +++ b/main.cpp @@ -145,8 +145,9 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; - if (n_parts < 1) + if (n_parts < 1) { n_parts = LLAMA_N_PARTS.at(hparams.n_embd); + } fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); fprintf(stderr, "%s: n_ctx = %d\n", __func__, hparams.n_ctx);