diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index bc9f6fa68..4f8e52986 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/examples/imatrix/imatrix.cpp
@@ -323,139 +323,87 @@ static void process_logits(
 }
 
 static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool compute_ppl, int from_chunk) {
+    (void)from_chunk;
-    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
+    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
-    auto tim1 = std::chrono::high_resolution_clock::now();
-    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
-
-    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
-
-    auto tim2 = std::chrono::high_resolution_clock::now();
-    fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
-
-    if (from_chunk > 0) {
-        if (size_t((from_chunk + 2)*n_ctx) >= tokens.size()) {
-            fprintf(stderr, "%s: there will be not enough tokens left after removing %d chunks\n", __func__, from_chunk);
-            return false;
-        }
-        fprintf(stderr, "%s: removing initial %d chunks (%d tokens)\n", __func__, from_chunk, from_chunk*n_ctx);
-        tokens.erase(tokens.begin(), tokens.begin() + from_chunk*n_ctx);
-    }
-
-    if (int(tokens.size()) < 2*n_ctx) {
-        fprintf(stderr, "%s: you need at least %d tokens for a context of %d tokens\n",__func__,2*n_ctx,
-                n_ctx);
-        fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
-        return false;
-    }
-
-    std::vector<float> logit_history;
-    std::vector<float> prob_history;
-
-    if (compute_ppl) {
-        logit_history.resize(tokens.size());
-        prob_history.resize(tokens.size());
-    }
-
-    const int n_chunk_max = tokens.size() / n_ctx;
-
-    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;
 
     int count = 0;
     double nll = 0.0;
     double nll2 = 0.0;
 
-    fprintf(stderr, "%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch);
+    std::vector<llama_token> tokens;
+    std::vector<float> logit_history;
+    std::vector<float> prob_history;
 
-    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
-    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+    size_t c_begin = 0;
+    while (true) {
+        const char* s_begin = "<|im_start|>system\n";
+        const char* s_assistant = "<|im_start|>assistant\n";
+        c_begin = params.prompt.find(s_begin, c_begin);
+        if (c_begin == std::string::npos) {
+            break;
+        }
+        size_t c_assistant = params.prompt.find(s_assistant, c_begin);
+        if (c_assistant == std::string::npos) {
+            break;
+        }
+        c_assistant += strlen(s_assistant);
+        size_t next_c_begin = params.prompt.find(s_begin, c_assistant);
+        auto s_prompt = params.prompt.substr(c_begin, c_assistant - c_begin);
+        auto s_response = params.prompt.substr(c_assistant, next_c_begin - c_assistant);
+        c_begin += 1;
 
-    std::vector<float> logits;
-    if (compute_ppl && num_batches > 1) {
-        logits.reserve((size_t)n_ctx * n_vocab);
-    }
-
-    for (int i = 0; i < n_chunk; ++i) {
-        const int start = i * n_ctx;
-        const int end = start + n_ctx;
-
-        std::vector<float> logits;
-
-        const auto t_start = std::chrono::high_resolution_clock::now();
-
-        // clear the KV cache
         llama_kv_cache_clear(ctx);
 
-        for (int j = 0; j < num_batches; ++j) {
-            const int batch_start = start + j * n_batch;
-            const int batch_size = std::min(end - batch_start, n_batch);
+        std::vector<llama_token> s_tokens_prompt = ::llama_tokenize(ctx, s_prompt, false);
+        std::vector<llama_token> s_tokens_response = ::llama_tokenize(ctx, s_response, false);
+        std::vector<llama_token> s_tokens = s_tokens_prompt;
+        s_tokens.insert(s_tokens.end(), s_tokens_response.begin(), s_tokens_response.end());
+        std::vector<float> s_logits;
+        std::vector<float> s_logit_history(s_tokens.size(), 0);
+        std::vector<float> s_prob_history(s_tokens.size(), 0);
 
-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
+        for (int j = 0; j < (int(s_tokens.size()) + n_batch - 1) / n_batch; ++j) {
+            const int batch_start = j * n_batch;
+            const int batch_size = std::min((int)s_tokens.size() - batch_start, n_batch);
 
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
-            }
-
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, llama_batch_get_one(s_tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return false;
             }
 
-            // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
-
-            if (compute_ppl && num_batches > 1) {
-                const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
-            }
+            const auto* batch_logits = llama_get_logits(ctx);
+            s_logits.insert(s_logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
         }
 
-        const auto t_end = std::chrono::high_resolution_clock::now();
+        const int first = s_tokens_prompt.size();
+        const float* all_logits = s_logits.data();
+        process_logits(n_vocab, all_logits + first * n_vocab, s_tokens.data() + first, s_tokens_response.size() - 1,
+                workers, nll, nll2, s_logit_history.data() + first, s_prob_history.data() + first);
+        count += s_tokens_response.size() - 1;
 
-        if (i == 0) {
-            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
-            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
-            int total_seconds = (int)(t_total * n_chunk);
-            if (total_seconds >= 60*60) {
-                fprintf(stderr, "%d hours ", total_seconds / (60*60));
-                total_seconds = total_seconds % (60*60);
-            }
-            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
-        }
+        printf(" %.4lf,", std::exp(nll / count));
+        fflush(stdout);
 
-        if (compute_ppl) {
-            const int first = n_ctx/2;
-            const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-            count += n_ctx - first - 1;
-
-            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
-            fflush(stdout);
-
-            logits.clear();
-        }
+        tokens.insert(tokens.end(), s_tokens.begin(), s_tokens.end());
+        logit_history.insert(logit_history.end(), s_logit_history.begin(), s_logit_history.end());
+        prob_history.insert(prob_history.end(), s_prob_history.begin(), s_prob_history.end());
     }
-    printf("\n");
 
-    if (compute_ppl) {
-        nll2 /= count;
-        nll /= count;
-        const double ppl = exp(nll);
-        nll2 -= nll * nll;
-        if (nll2 > 0) {
-            nll2 = sqrt(nll2/(count-1));
-            printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
-        } else {
-            printf("Unexpected negative standard deviation of log(prob)\n");
-        }
+    nll2 /= count;
+    nll /= count;
+    const double ppl = exp(nll);
+    nll2 -= nll * nll;
+    if (nll2 > 0) {
+        nll2 = sqrt(nll2 / (count - 1));
+        printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2 * ppl);
+    }
+    else {
+        printf("Unexpected negative standard deviation of log(prob)\n");
     }
 
     return true;
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index b2c131d4c..29a25b941 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -442,7 +442,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     return {tokens, std::exp(nll / count), logit_history, prob_history};
 }
 
-static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
+static results_perplexity perplexity(llama_context* ctx, const gpt_params& params) {
     if (params.ppl_stride > 0) {
         return perplexity_v2(ctx, params);
     }
@@ -452,170 +452,88 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
 
-    const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
+    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
-    std::ofstream logits_stream;
-    if (!params.logits_file.empty()) {
-        logits_stream.open(params.logits_file.c_str(), std::ios::binary);
-        if (!logits_stream.is_open()) {
-            fprintf(stderr, "%s: failed to open %s for writing\n", __func__, params.logits_file.c_str());
-            return {};
-        }
-        fprintf(stderr, "%s: saving all logits to %s\n", __func__, params.logits_file.c_str());
-        logits_stream.write("_logits_", 8);
-        logits_stream.write(reinterpret_cast<const char *>(&n_ctx), sizeof(n_ctx));
-    }
-
-    auto tim1 = std::chrono::high_resolution_clock::now();
-    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
-
-    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
-
-    auto tim2 = std::chrono::high_resolution_clock::now();
-    fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
-
-    if (int(tokens.size()) < 2*n_ctx) {
-        fprintf(stderr, "%s: you need at least %d tokens to evaluate perplexity with a context of %d\n",__func__,2*n_ctx,
-                n_ctx);
-        fprintf(stderr, "%s: the data file you provided tokenizes to only %zu tokens\n",__func__,tokens.size());
-        return {std::move(tokens), 0., {}, {}};
-    }
-
-    std::vector<float> logit_history;
-    logit_history.resize(tokens.size());
-
-    std::vector<float> prob_history;
-    prob_history.resize(tokens.size());
-
-    const int n_chunk_max = tokens.size() / n_ctx;
-
-    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_batch = params.n_batch;
 
     int count = 0;
     double nll = 0.0;
     double nll2 = 0.0;
 
-    const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+    std::vector<llama_token> tokens;
+    std::vector<float> logit_history;
+    std::vector<float> prob_history;
 
-    std::vector<float> logits;
-    if (num_batches > 1) {
-        logits.reserve((size_t)n_ctx * n_vocab);
-    }
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
-    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+    size_t c_begin = 0;
+    while (true) {
+        const char* s_begin = "<|im_start|>system\n";
+        const char* s_assistant = "<|im_start|>assistant\n";
+        c_begin = params.prompt.find(s_begin, c_begin);
+        if (c_begin == std::string::npos) {
+            break;
+        }
+        size_t c_assistant = params.prompt.find(s_assistant, c_begin);
+        if (c_assistant == std::string::npos) {
+            break;
+        }
+        c_assistant += strlen(s_assistant);
+        size_t next_c_begin = params.prompt.find(s_begin, c_assistant);
+        auto s_prompt = params.prompt.substr(c_begin, c_assistant - c_begin);
+        auto s_response = params.prompt.substr(c_assistant, next_c_begin - c_assistant);
+        c_begin += 1;
 
-    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
-
-    std::vector<uint16_t> log_probs;
-    if (!params.logits_file.empty()) {
-        logits_stream.write((const char *)&n_vocab, sizeof(n_vocab));
-        logits_stream.write((const char *)&n_chunk, sizeof(n_chunk));
-        logits_stream.write((const char *)tokens.data(), n_chunk*n_ctx*sizeof(tokens[0]));
-        const int nv = 2*((n_vocab + 1)/2) + 4;
-        log_probs.resize(n_ctx * nv);
-    }
-
-    for (int i = 0; i < n_chunk; ++i) {
-        const int start = i * n_ctx;
-        const int end = start + n_ctx;
-
-        const auto t_start = std::chrono::high_resolution_clock::now();
-
-        // clear the KV cache
         llama_kv_cache_clear(ctx);
 
-        for (int j = 0; j < num_batches; ++j) {
-            const int batch_start = start + j * n_batch;
-            const int batch_size = std::min(end - batch_start, n_batch);
+        std::vector<llama_token> s_tokens_prompt = ::llama_tokenize(ctx, s_prompt, false);
+        std::vector<llama_token> s_tokens_response = ::llama_tokenize(ctx, s_response, false);
+        std::vector<llama_token> s_tokens = s_tokens_prompt;
+        s_tokens.insert(s_tokens.end(), s_tokens_response.begin(), s_tokens_response.end());
+        std::vector<float> s_logits;
+        std::vector<float> s_logit_history(s_tokens.size(), 0);
+        std::vector<float> s_prob_history(s_tokens.size(), 0);
 
-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
+        for (int j = 0; j < (int(s_tokens.size()) + n_batch - 1) / n_batch; ++j) {
+            const int batch_start = j * n_batch;
+            const int batch_size = std::min((int)s_tokens.size() - batch_start, n_batch);
 
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
-            }
-
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, llama_batch_get_one(s_tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
-                return {tokens, -1, logit_history, prob_history};
+                return { tokens, -1, logit_history, prob_history };
             }
 
-            // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
-
-            if (num_batches > 1) {
-                const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
-            }
+            const auto* batch_logits = llama_get_logits(ctx);
+            s_logits.insert(s_logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
        }
 
-        const auto t_end = std::chrono::high_resolution_clock::now();
+        const int first = s_tokens_prompt.size();
+        const float* all_logits = s_logits.data();
+        process_logits(n_vocab, all_logits + first * n_vocab, s_tokens.data() + first, s_tokens_response.size() - 1,
+                workers, nll, nll2, s_logit_history.data() + first, s_prob_history.data() + first);
+        count += s_tokens_response.size() - 1;
 
-        if (i == 0) {
-            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
-            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
-            int total_seconds = (int)(t_total * n_chunk);
-            if (total_seconds >= 60*60) {
-                fprintf(stderr, "%d hours ", total_seconds / (60*60));
-                total_seconds = total_seconds % (60*60);
-            }
-            fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
-        }
-
-        // We get the logits for all the tokens in the context window (params.n_ctx)
-        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
-        // calculate the perplexity over the last half of the window (so the model always has
-        // some context to predict the token).
-        //
-        // We rely on the fact that attention in the forward pass only looks at previous
-        // tokens here, so the logits returned for each token are an accurate representation
-        // of what the model would have predicted at that point.
-        //
-        // Example, we have a context window of 512, we will compute perplexity for each of the
-        // last 256 tokens. Then, we split the input up into context window size chunks to
-        // process the entire prompt.
-        const int first = n_ctx/2;
-        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        if (!params.logits_file.empty()) {
-            process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, log_probs, nll, nll2);
-        } else {
-            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-        }
-        count += n_ctx - first - 1;
-
-        // perplexity is e^(average negative log-likelihood)
-        if (params.ppl_output_type == 0) {
-            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
-        } else {
-            double av = nll/count;
-            double av2 = nll2/count - av*av;
-            if (av2 > 0) av2 = sqrt(av2/(count-1));
-            printf("%8d  %.4lf  %4lf  %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
-        }
+        printf(" %.4lf,", std::exp(nll / count));
         fflush(stdout);
 
-        logits.clear();
+        tokens.insert(tokens.end(), s_tokens.begin(), s_tokens.end());
+        logit_history.insert(logit_history.end(), s_logit_history.begin(), s_logit_history.end());
+        prob_history.insert(prob_history.end(), s_prob_history.begin(), s_prob_history.end());
     }
-    printf("\n");
 
     nll2 /= count;
     nll /= count;
     const double ppl = exp(nll);
     nll2 -= nll * nll;
     if (nll2 > 0) {
-        nll2 = sqrt(nll2/(count-1));
-        printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
-    } else {
+        nll2 = sqrt(nll2 / (count - 1));
+        printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2 * ppl);
+    }
+    else {
         printf("Unexpected negative standard deviation of log(prob)\n");
     }
 
-    return {tokens, ppl, logit_history, prob_history};
+    return { tokens, ppl, logit_history, prob_history };
 }
 
 static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {