diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index e46083707..bc9f6fa68 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -322,7 +322,7 @@ static void process_logits( } } -static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl) { +static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) { const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); const int n_ctx = llama_n_ctx(ctx); @@ -335,6 +335,15 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool auto tim2 = std::chrono::high_resolution_clock::now(); fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count()); + if (from_chunk > 0) { + if (size_t((from_chunk + 2)*n_ctx) >= tokens.size()) { + fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, from_chunk); + return false; + } + fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk*n_ctx); + tokens.erase(tokens.begin(), tokens.begin() + from_chunk*n_ctx); + } + if (int(tokens.size()) < 2*n_ctx) { fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx, n_ctx); @@ -458,6 +467,7 @@ int main(int argc, char ** argv) { std::string prev_result_file; std::string combine_files; bool compute_ppl = true; + int from_chunk = 0; std::vector args; args.push_back(argv[0]); int iarg = 1; @@ -482,6 +492,9 @@ int main(int argc, char ** argv) { prev_result_file = argv[++iarg]; } else if (arg == "--combine") { combine_files = argv[++iarg]; + } + else if (arg == "--from-chunk") { + from_chunk = std::stoi(argv[++iarg]); } else { args.push_back(argv[iarg]); } @@ -590,7 +603,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s\n", get_system_info(params).c_str()); } - bool OK = compute_imatrix(ctx, params, compute_ppl); + bool OK = compute_imatrix(ctx, params, compute_ppl, from_chunk); if (!OK) { return 1; }