llama : update llama_kv_self API
ggml-ci
This commit is contained in:
parent
fd05ab87aa
commit
17b363afd3
30 changed files with 387 additions and 205 deletions
|
@ -299,8 +299,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
|
|||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
llama_kv_cache * kv = llama_get_kv_cache(ctx);
|
||||
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
||||
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
|
||||
|
||||
|
@ -362,7 +360,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
|
|||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// clear the KV cache
|
||||
llama_kv_cache_clear(kv);
|
||||
llama_kv_self_clear(ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
|
||||
|
@ -452,8 +450,6 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
|
|||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
llama_kv_cache * kv = llama_get_kv_cache(ctx);
|
||||
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
||||
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
|
||||
|
||||
|
@ -550,7 +546,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
|
|||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// clear the KV cache
|
||||
llama_kv_cache_clear(kv);
|
||||
llama_kv_self_clear(ctx);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
|
@ -745,8 +741,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
|||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
llama_kv_cache * kv = llama_get_kv_cache(ctx);
|
||||
|
||||
// Calculates hellaswag score (acc_norm) from prompt
|
||||
//
|
||||
// Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
|
||||
|
@ -929,7 +923,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
|||
return;
|
||||
}
|
||||
|
||||
llama_kv_cache_clear(kv);
|
||||
llama_kv_self_clear(ctx);
|
||||
|
||||
// decode all tasks [i0, i1)
|
||||
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
|
||||
|
@ -1090,8 +1084,6 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
|
|||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
llama_kv_cache * kv = llama_get_kv_cache(ctx);
|
||||
|
||||
constexpr int k_min_trailing_ctx = 3;
|
||||
|
||||
auto data = load_winogrande_from_csv(params.prompt);
|
||||
|
@ -1210,7 +1202,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
|
|||
return;
|
||||
}
|
||||
|
||||
llama_kv_cache_clear(kv);
|
||||
llama_kv_self_clear(ctx);
|
||||
|
||||
// decode all tasks [i0, i1)
|
||||
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
|
||||
|
@ -1396,8 +1388,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
|
|||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
llama_kv_cache * kv = llama_get_kv_cache(ctx);
|
||||
|
||||
std::istringstream strstream(params.prompt);
|
||||
uint32_t n_task;
|
||||
strstream.read((char *)&n_task, sizeof(n_task));
|
||||
|
@ -1584,7 +1574,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
|
|||
return;
|
||||
}
|
||||
|
||||
llama_kv_cache_clear(kv);
|
||||
llama_kv_self_clear(ctx);
|
||||
|
||||
// decode all tasks [i0, i1)
|
||||
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
|
||||
|
@ -1681,8 +1671,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
|||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
llama_kv_cache * kv = llama_get_kv_cache(ctx);
|
||||
|
||||
if (params.logits_file.empty()) {
|
||||
LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
|
||||
return;
|
||||
|
@ -1776,7 +1764,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
|||
}
|
||||
|
||||
// clear the KV cache
|
||||
llama_kv_cache_clear(kv);
|
||||
llama_kv_self_clear(ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue