llama : cont

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-13 14:56:52 +02:00
parent f78b396ee7
commit e4550fbafc
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
19 changed files with 128 additions and 79 deletions

View file

@ -299,6 +299,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
@ -360,7 +362,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_kv_cache_clear(kv);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
@ -450,6 +452,8 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
@ -546,7 +550,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_kv_cache_clear(kv);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@ -741,6 +745,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
// Calculates hellaswag score (acc_norm) from prompt
//
// Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
@ -923,7 +929,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
return;
}
llama_kv_cache_clear(ctx);
llama_kv_cache_clear(kv);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1084,6 +1090,8 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
constexpr int k_min_trailing_ctx = 3;
auto data = load_winogrande_from_csv(params.prompt);
@ -1202,7 +1210,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
return;
}
llama_kv_cache_clear(ctx);
llama_kv_cache_clear(kv);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1388,6 +1396,8 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
std::istringstream strstream(params.prompt);
uint32_t n_task;
strstream.read((char *)&n_task, sizeof(n_task));
@ -1574,7 +1584,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
return;
}
llama_kv_cache_clear(ctx);
llama_kv_cache_clear(kv);
// decode all tasks [i0, i1)
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
@ -1671,6 +1681,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
if (params.logits_file.empty()) {
LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
return;
@ -1764,7 +1776,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
}
// clear the KV cache
llama_kv_cache_clear(ctx);
llama_kv_cache_clear(kv);
llama_batch batch = llama_batch_init(n_batch, 0, 1);