diff --git a/common/sampling.cpp b/common/sampling.cpp
index f400aa7fb..a197011de 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -337,7 +337,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
     {
         const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n);
         if (penalty_tokens_used_size) {
-            llama_sample_dry(ctx_main, &cur_p,
+            llama_sample_dry(&cur_p,
                             penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                             penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length,
                             params.dry_seq_breakers.data(), params.dry_seq_breakers.size());
diff --git a/llama.cpp b/llama.cpp
index 709baee63..a4cb1f284 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13233,7 +13233,7 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
     }
 }
 
-void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, int dry_seq_breakers_size) {
+void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, int last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, int dry_seq_breakers_size) {
     // skip dry sampler if we don't have a previous token
     if (last_tokens_size < 1) return;
 
diff --git a/llama.h b/llama.h
index cb58daea0..774c1b222 100644
--- a/llama.h
+++ b/llama.h
@@ -926,7 +926,6 @@ extern "C" {
     /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677
     LLAMA_API void llama_sample_dry(
-            struct llama_context * ctx,
           llama_token_data_array * candidates,
                const llama_token * last_tokens,
                              int   last_tokens_size,