imatrix: be able to start from a specific chunk
This commit is contained in:
parent
935227bf32
commit
4e0d6dd9c1
1 changed file with 15 additions and 2 deletions
|
@ -322,7 +322,7 @@ static void process_logits(
|
|||
}
|
||||
}
|
||||
|
||||
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl) {
|
||||
static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
|
||||
|
||||
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
@ -335,6 +335,15 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
|
|||
auto tim2 = std::chrono::high_resolution_clock::now();
|
||||
fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
|
||||
|
||||
if (from_chunk > 0) {
|
||||
if (size_t((from_chunk + 2)*n_ctx) >= tokens.size()) {
|
||||
fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, from_chunk);
|
||||
return false;
|
||||
}
|
||||
fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk*n_ctx);
|
||||
tokens.erase(tokens.begin(), tokens.begin() + from_chunk*n_ctx);
|
||||
}
|
||||
|
||||
if (int(tokens.size()) < 2*n_ctx) {
|
||||
fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx,
|
||||
n_ctx);
|
||||
|
@ -458,6 +467,7 @@ int main(int argc, char ** argv) {
|
|||
std::string prev_result_file;
|
||||
std::string combine_files;
|
||||
bool compute_ppl = true;
|
||||
int from_chunk = 0;
|
||||
std::vector<char*> args;
|
||||
args.push_back(argv[0]);
|
||||
int iarg = 1;
|
||||
|
@ -482,6 +492,9 @@ int main(int argc, char ** argv) {
|
|||
prev_result_file = argv[++iarg];
|
||||
} else if (arg == "--combine") {
|
||||
combine_files = argv[++iarg];
|
||||
}
|
||||
else if (arg == "--from-chunk") {
|
||||
from_chunk = std::stoi(argv[++iarg]);
|
||||
} else {
|
||||
args.push_back(argv[iarg]);
|
||||
}
|
||||
|
@ -590,7 +603,7 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "%s\n", get_system_info(params).c_str());
|
||||
}
|
||||
|
||||
bool OK = compute_imatrix(ctx, params, compute_ppl);
|
||||
bool OK = compute_imatrix(ctx, params, compute_ppl, from_chunk);
|
||||
if (!OK) {
|
||||
return 1;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue