DRY: Fixed crash issue due to DRY being in chain but uninitialized
This commit is contained in:
parent
13038930af
commit
dc408bba7d
1 changed files with 8 additions and 16 deletions
|
@ -1747,7 +1747,7 @@ static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/
|
|||
|
||||
static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
|
||||
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
||||
if (ctx->dry_penalty_last_n == 0) {
|
||||
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1758,7 +1758,7 @@ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token to
|
|||
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
||||
|
||||
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f) {
|
||||
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1999,18 +1999,15 @@ static struct llama_sampler_i llama_sampler_dry_i = {
|
|||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
||||
if (dry_multiplier == 0.0f || dry_base < 1.0f) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
|
||||
|
||||
// Process sequence breakers
|
||||
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
|
||||
const int MAX_CHAR_LEN = 40;
|
||||
const int MAX_SEQ_LEN = 20;
|
||||
|
||||
if (seq_breakers != nullptr && num_breakers > 0) {
|
||||
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
|
||||
|
||||
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
|
||||
// Process sequence breakers
|
||||
for (size_t i = 0; i < num_breakers; ++i) {
|
||||
if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
|
||||
LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
|
||||
|
@ -2041,9 +2038,9 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
|
|||
/* .dry_allowed_length = */ dry_allowed_length,
|
||||
/* .dry_penalty_last_n = */ dry_penalty_last_n,
|
||||
/* .dry_processed_breakers = */ std::move(processed_breakers),
|
||||
/* .dry_repeat_count = */ std::vector<int>(effective_dry_penalty_last_n, 0),
|
||||
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
|
||||
/* .dry_max_token_repeat = */ {},
|
||||
/* .last_tokens = */ ring_buffer<llama_token>(effective_dry_penalty_last_n),
|
||||
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
@ -2052,11 +2049,6 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
|
|||
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
|
||||
llama_vocab dummy_vocab;
|
||||
auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
|
||||
|
||||
if (result == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto * ctx = (llama_sampler_dry *) result->ctx;
|
||||
|
||||
// Process the token-based sequence breakers
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue