perplexity : ignore n_batch, submit whole chunk in one call
This commit is contained in:
parent
0068da7fef
commit
e264f2239e
2 changed files with 44 additions and 48 deletions
|
@ -189,19 +189,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
||||||
|
|
||||||
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
|
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
const int n_batch = params.n_batch;
|
|
||||||
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
double nll = 0.0;
|
double nll = 0.0;
|
||||||
|
|
||||||
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
|
fprintf(stderr, "%s: calculating perplexity over %d chunks\n", __func__, n_chunk);
|
||||||
|
|
||||||
for (int i = 0; i < n_chunk; ++i) {
|
for (int i = 0; i < n_chunk; ++i) {
|
||||||
const int start = i * params.ppl_stride;
|
const int start = i * params.ppl_stride;
|
||||||
const int end = start + calc_chunk;
|
//const int end = start + calc_chunk;
|
||||||
|
|
||||||
const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
|
|
||||||
//fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches);
|
|
||||||
|
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
|
|
||||||
|
@ -210,32 +206,26 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
||||||
// clear the KV cache
|
// clear the KV cache
|
||||||
llama_kv_cache_clear(ctx);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
for (int j = 0; j < num_batches; ++j) {
|
|
||||||
const int batch_start = start + j * n_batch;
|
|
||||||
const int batch_size = std::min(end - batch_start, n_batch);
|
|
||||||
|
|
||||||
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
|
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
|
||||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + start, calc_chunk, 0, 0))) {
|
||||||
//fprintf(stderr, "%s : failed to eval\n", __func__);
|
//fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return {tokens, -1, logit_history, prob_history};
|
return {tokens, -1, logit_history, prob_history};
|
||||||
}
|
|
||||||
|
|
||||||
// save original token and restore it after eval
|
|
||||||
const auto token_org = tokens[batch_start];
|
|
||||||
|
|
||||||
// add BOS token for the first batch of each chunk
|
|
||||||
if (add_bos && j == 0) {
|
|
||||||
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto batch_logits = llama_get_logits(ctx);
|
|
||||||
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
|
|
||||||
|
|
||||||
if (j == 0) {
|
|
||||||
tokens[batch_start] = token_org;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// save original token and restore it after eval
|
||||||
|
const auto token_org = tokens[start];
|
||||||
|
|
||||||
|
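// add BOS token at the start of each chunk
// NOTE(review): llama_decode() has already been called above for this chunk, so this swap does not change the tokens that were evaluated — confirm whether it should happen before the decode call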
|
||||||
|
if (add_bos) {
|
||||||
|
tokens[start] = llama_token_bos(llama_get_model(ctx));
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto * batch_logits = llama_get_logits(ctx);
|
||||||
|
logits.insert(logits.end(), batch_logits, batch_logits + calc_chunk * n_vocab);
|
||||||
|
|
||||||
|
tokens[start] = token_org;
|
||||||
|
|
||||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
|
@ -246,7 +236,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
||||||
fprintf(stderr, "%d hours ", total_seconds / (60*60));
|
fprintf(stderr, "%d hours ", total_seconds / (60*60));
|
||||||
total_seconds = total_seconds % (60*60);
|
total_seconds = total_seconds % (60*60);
|
||||||
}
|
}
|
||||||
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
|
fprintf(stderr, "%.2f minutes ", total_seconds / 60.0);
|
||||||
|
fprintf(stderr, "(%.2f t/s)\n", n_ctx/t_total);
|
||||||
}
|
}
|
||||||
|
|
||||||
//fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
|
//fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start);
|
||||||
|
@ -327,9 +318,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
|
|
||||||
for (int i = 0; i < n_chunk; ++i) {
|
for (int i = 0; i < n_chunk; ++i) {
|
||||||
const int start = i * n_ctx;
|
const int start = i * n_ctx;
|
||||||
const int end = start + n_ctx;
|
//const int end = start + n_ctx;
|
||||||
|
|
||||||
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
|
//const int num_batches = (n_ctx + n_batch - 1) / n_batch;
|
||||||
|
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
|
|
||||||
|
@ -338,33 +329,33 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
// clear the KV cache
|
// clear the KV cache
|
||||||
llama_kv_cache_clear(ctx);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
for (int j = 0; j < num_batches; ++j) {
|
//for (int j = 0; j < num_batches; ++j) {
|
||||||
const int batch_start = start + j * n_batch;
|
// const int batch_start = start + j * n_batch;
|
||||||
const int batch_size = std::min(end - batch_start, n_batch);
|
// const int batch_size = std::min(end - batch_start, n_batch);
|
||||||
|
|
||||||
// save original token and restore it after eval
|
// save original token and restore it after eval
|
||||||
const auto token_org = tokens[batch_start];
|
const auto token_org = tokens[start];
|
||||||
|
|
||||||
// add BOS token for the first batch of each chunk
|
// add BOS token at the start of each chunk (the whole chunk is now submitted in a single call)
|
||||||
if (add_bos && j == 0) {
|
if (add_bos) {
|
||||||
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
|
tokens[start] = llama_token_bos(llama_get_model(ctx));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + start, n_ctx, 0, 0))) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return {tokens, -1, logit_history, prob_history};
|
return {tokens, -1, logit_history, prob_history};
|
||||||
}
|
}
|
||||||
|
|
||||||
// restore the original token in case it was set to BOS
|
// restore the original token in case it was set to BOS
|
||||||
tokens[batch_start] = token_org;
|
tokens[start] = token_org;
|
||||||
|
|
||||||
const auto * batch_logits = llama_get_logits(ctx);
|
const auto * batch_logits = llama_get_logits(ctx);
|
||||||
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
|
logits.insert(logits.end(), batch_logits, batch_logits + n_ctx * n_vocab);
|
||||||
}
|
//}
|
||||||
|
|
||||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
if (i == 0) {
|
if (i == 1) { // TODO: skipping the first chunk gives a better estimate, but breaks formatting
|
||||||
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
|
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
|
||||||
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
|
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
|
||||||
int total_seconds = (int)(t_total * n_chunk);
|
int total_seconds = (int)(t_total * n_chunk);
|
||||||
|
@ -372,7 +363,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
fprintf(stderr, "%d hours ", total_seconds / (60*60));
|
fprintf(stderr, "%d hours ", total_seconds / (60*60));
|
||||||
total_seconds = total_seconds % (60*60);
|
total_seconds = total_seconds % (60*60);
|
||||||
}
|
}
|
||||||
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
|
fprintf(stderr, "%.2f minutes ", total_seconds / 60.0);
|
||||||
|
fprintf(stderr, "(%.2f t/s)\n", n_ctx/t_total);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// We get the logits for all the tokens in the context window (params.n_ctx)
|
// We get the logits for all the tokens in the context window (params.n_ctx)
|
||||||
|
@ -433,7 +426,7 @@ static std::vector<float> hellaswag_evaluate_tokens(
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto logits = llama_get_logits(ctx);
|
const auto * logits = llama_get_logits(ctx);
|
||||||
result.insert(result.end(), logits, logits + n_tokens * n_vocab);
|
result.insert(result.end(), logits, logits + n_tokens * n_vocab);
|
||||||
|
|
||||||
n_past += n_tokens;
|
n_past += n_tokens;
|
||||||
|
@ -678,13 +671,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
|
||||||
params.n_batch = 512;
|
//params.n_batch = 512;
|
||||||
if (!gpt_params_parse(argc, argv, params)) {
|
if (!gpt_params_parse(argc, argv, params)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
params.logits_all = true;
|
params.logits_all = true;
|
||||||
params.n_batch = std::min(params.n_batch, params.n_ctx);
|
//params.n_batch = std::min(params.n_batch, params.n_ctx);
|
||||||
|
|
||||||
if (params.ppl_stride > 0) {
|
if (params.ppl_stride > 0) {
|
||||||
fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
|
fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
|
||||||
|
|
|
@ -6222,7 +6222,7 @@ static int llama_decode_internal(
|
||||||
logits_valid.clear();
|
logits_valid.clear();
|
||||||
logits_valid.resize(n_tokens_all);
|
logits_valid.resize(n_tokens_all);
|
||||||
|
|
||||||
logits_out.clear();
|
memset(logits_out, 0, lctx.logits_size*sizeof(float));
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
@ -6428,6 +6428,9 @@ static int llama_decode_internal(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//ggml_backend_sched_synchronize(lctx.sched);
|
||||||
|
//lctx.buf_cpu_ub_cur = 0;
|
||||||
|
|
||||||
// measure the performance only for the single-token evals
|
// measure the performance only for the single-token evals
|
||||||
if (n_tokens_all == 1) {
|
if (n_tokens_all == 1) {
|
||||||
lctx.t_eval_us += ggml_time_us() - t_start_us;
|
lctx.t_eval_us += ggml_time_us() - t_start_us;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue