imatrix: add the ability to start from a specific chunk

This commit is contained in:
Iwan Kawrakow 2024-02-03 13:30:39 +02:00
parent 935227bf32
commit 4e0d6dd9c1

View file

@ -322,7 +322,7 @@ static void process_logits(
} }
} }
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl) { static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
@ -335,6 +335,15 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
auto tim2 = std::chrono::high_resolution_clock::now(); auto tim2 = std::chrono::high_resolution_clock::now();
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count()); fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
if (from_chunk > 0) {
if (size_t((from_chunk + 2)*n_ctx) >= tokens.size()) {
fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, from_chunk);
return false;
}
fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk*n_ctx);
tokens.erase(tokens.begin(), tokens.begin() + from_chunk*n_ctx);
}
if (int(tokens.size()) < 2*n_ctx) { if (int(tokens.size()) < 2*n_ctx) {
fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx, fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx,
n_ctx); n_ctx);
@ -458,6 +467,7 @@ int main(int argc, char ** argv) {
std::string prev_result_file; std::string prev_result_file;
std::string combine_files; std::string combine_files;
bool compute_ppl = true; bool compute_ppl = true;
int from_chunk = 0;
std::vector<char*> args; std::vector<char*> args;
args.push_back(argv[0]); args.push_back(argv[0]);
int iarg = 1; int iarg = 1;
@ -482,6 +492,9 @@ int main(int argc, char ** argv) {
prev_result_file = argv[++iarg]; prev_result_file = argv[++iarg];
} else if (arg == "--combine") { } else if (arg == "--combine") {
combine_files = argv[++iarg]; combine_files = argv[++iarg];
}
else if (arg == "--from-chunk") {
from_chunk = std::stoi(argv[++iarg]);
} else { } else {
args.push_back(argv[iarg]); args.push_back(argv[iarg]);
} }
@ -590,7 +603,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s\n", get_system_info(params).c_str()); fprintf(stderr, "%s\n", get_system_info(params).c_str());
} }
bool OK = compute_imatrix(ctx, params, compute_ppl); bool OK = compute_imatrix(ctx, params, compute_ppl, from_chunk);
if (!OK) { if (!OK) {
return 1; return 1;
} }