examples : de-shadow

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-12 14:25:32 +02:00
parent 82caffa74e
commit 9a735ae6d8
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
16 changed files with 152 additions and 159 deletions

View file

@ -348,8 +348,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
for (int i = 0; i < n_chunk; ++i) {
const int start = i * params.ppl_stride;
for (int ich = 0; ich < n_chunk; ++ich) {
const int start = ich * params.ppl_stride;
const int end = start + calc_chunk;
const int num_batches = (calc_chunk + n_batch - 1) / n_batch;
@ -400,7 +400,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
if (ich == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
@ -427,9 +427,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
}
// perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) {
LOG("[%d]%.4lf,", i + 1, std::exp(nll / count));
LOG("[%d]%.4lf,", ich + 1, std::exp(nll / count));
} else {
LOG("%8d %.4lf\n", i*params.ppl_stride, std::exp(nll / count));
LOG("%8d %.4lf\n", ich*params.ppl_stride, std::exp(nll / count));
}
}
LOG("\n");
@ -659,7 +659,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
int prev_outputs = 0;
for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
for (int i = 0; i < batch.n_tokens; i += n_batch) {
const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);
llama_batch batch_view = {
@ -679,8 +679,8 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
}
int n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
n_outputs += batch_view.logits[i] != 0;
for (int iv = 0; iv < n_tokens; ++iv) {
n_outputs += batch_view.logits[iv] != 0;
}
memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@ -1752,14 +1752,14 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
auto kld_ptr = kld_values.data();
auto p_diff_ptr = p_diff_values.data();
for (int i = 0; i < n_chunk; ++i) {
const int start = i * n_ctx;
for (int ich = 0; ich < n_chunk; ++ich) {
const int start = ich * n_ctx;
const int end = start + n_ctx;
const auto t_start = std::chrono::high_resolution_clock::now();
if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, ich);
return;
}
@ -1804,7 +1804,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
if (ich == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
@ -1824,7 +1824,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
p_diff_ptr += n_ctx - 1 - first;
kld_ptr += n_ctx - 1 - first;
LOG("%4d", i+1);
LOG("%4d", ich + 1);
auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
const double ppl_val = exp(log_ppl.first);