DRY: Fixed crash issue due to DRY being in chain but uninitialized

This commit is contained in:
wwoodsTM 2024-10-24 01:15:04 -06:00
parent 13038930af
commit dc408bba7d

View file

@ -1747,7 +1747,7 @@ static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/
static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
auto * ctx = (llama_sampler_dry *) smpl->ctx;
if (ctx->dry_penalty_last_n == 0) {
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
return;
}
@ -1758,7 +1758,7 @@ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token to
static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_dry *) smpl->ctx;
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f) {
if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
return;
}
@ -1999,18 +1999,15 @@ static struct llama_sampler_i llama_sampler_dry_i = {
};
struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
if (dry_multiplier == 0.0f || dry_base < 1.0f) {
return nullptr;
}
int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
// Process sequence breakers
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
const int MAX_CHAR_LEN = 40;
const int MAX_SEQ_LEN = 20;
if (seq_breakers != nullptr && num_breakers > 0) {
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
// Process sequence breakers
for (size_t i = 0; i < num_breakers; ++i) {
if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
@ -2041,9 +2038,9 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
/* .dry_allowed_length = */ dry_allowed_length,
/* .dry_penalty_last_n = */ dry_penalty_last_n,
/* .dry_processed_breakers = */ std::move(processed_breakers),
/* .dry_repeat_count = */ std::vector<int>(effective_dry_penalty_last_n, 0),
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
/* .dry_max_token_repeat = */ {},
/* .last_tokens = */ ring_buffer<llama_token>(effective_dry_penalty_last_n),
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
},
};
}
@ -2052,11 +2049,6 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
llama_vocab dummy_vocab;
auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
if (result == nullptr) {
return nullptr;
}
auto * ctx = (llama_sampler_dry *) result->ctx;
// Process the token-based sequence breakers