DRY: Fixed crash issue due to DRY being in chain but uninitialized
parent 13038930af
commit dc408bba7d
1 changed file with 8 additions and 16 deletions
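Summary (inferred from the diff below): llama_sampler_init_dry_impl previously returned nullptr when dry_multiplier == 0.0f or dry_base < 1.0f, so a caller that unconditionally placed DRY in its sampler chain could end up holding an uninitialized (null) sampler and crash. The patch always constructs the sampler, turns accept/apply into cheap no-ops whenever the parameters disable DRY, and skips sequence-breaker processing and buffer allocation in that case.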
@@ -1747,7 +1747,7 @@ static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/
 static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
     auto * ctx = (llama_sampler_dry *) smpl->ctx;
-    if (ctx->dry_penalty_last_n == 0) {
+    if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
         return;
     }
@@ -1758,7 +1758,7 @@ static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token to
 static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dry *) smpl->ctx;
 
-    if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f) {
+    if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
         return;
     }
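With these two hunks, llama_sampler_dry_accept() and llama_sampler_dry_apply() share the same early-out: DRY is a no-op whenever any of the three parameters disables it. A minimal self-contained sketch of that combined predicate (dry_is_disabled is a hypothetical helper written here for illustration, not part of the patch):

#include <cstdint>
#include <cstdio>

// Mirrors the guard now used in llama_sampler_dry_accept() and
// llama_sampler_dry_apply(): DRY does nothing when the multiplier is zero,
// the base is below 1, or the penalty window is empty.
static bool dry_is_disabled(float dry_multiplier, float dry_base, int32_t dry_penalty_last_n) {
    return dry_multiplier == 0.0f || dry_base < 1.0f || dry_penalty_last_n == 0;
}

int main() {
    std::printf("%d\n", dry_is_disabled(0.0f, 1.75f, -1)); // 1: zero multiplier disables DRY
    std::printf("%d\n", dry_is_disabled(0.8f, 1.75f,  0)); // 1: empty penalty window disables DRY
    std::printf("%d\n", dry_is_disabled(0.8f, 1.75f, -1)); // 0: DRY is active
    return 0;
}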
@@ -1999,18 +1999,15 @@ static struct llama_sampler_i llama_sampler_dry_i = {
 };
 
 struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
-    if (dry_multiplier == 0.0f || dry_base < 1.0f) {
-        return nullptr;
-    }
-
     int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
 
-    // Process sequence breakers
     std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
     const int MAX_CHAR_LEN = 40;
     const int MAX_SEQ_LEN = 20;
 
-    if (seq_breakers != nullptr && num_breakers > 0) {
+    const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
+
+    if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
+        // Process sequence breakers
         for (size_t i = 0; i < num_breakers; ++i) {
             if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
                 LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
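This is the behavioral core of the fix: the constructor no longer rejects disabled parameter combinations with an early return nullptr. Instead, a single dry_enabled flag records whether DRY will actually run, and sequence-breaker processing only happens in the enabled case. A disabled DRY sampler is still a valid object that simply does nothing.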
@@ -2041,9 +2038,9 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
             /* .dry_allowed_length = */ dry_allowed_length,
             /* .dry_penalty_last_n = */ dry_penalty_last_n,
             /* .dry_processed_breakers = */ std::move(processed_breakers),
-            /* .dry_repeat_count = */ std::vector<int>(effective_dry_penalty_last_n, 0),
+            /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
             /* .dry_max_token_repeat = */ {},
-            /* .last_tokens = */ ring_buffer<llama_token>(effective_dry_penalty_last_n),
+            /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
         },
     };
 }
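When DRY is disabled, the per-token bookkeeping buffers no longer need to cover the penalty window, so they are built empty (and the ring buffer with capacity 0). A self-contained sketch of that allocation pattern, with std::vector standing in for the repository's ring_buffer and all values chosen only for illustration:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const bool    dry_enabled                  = false; // e.g. dry_multiplier == 0.0f
    const int32_t effective_dry_penalty_last_n = 4096;  // window size when DRY is enabled

    // Allocate the repeat-count buffer only when DRY will actually run, so a
    // disabled sampler sitting in the chain costs essentially nothing.
    std::vector<int> dry_repeat_count = dry_enabled
        ? std::vector<int>(effective_dry_penalty_last_n, 0)
        : std::vector<int>{};

    std::printf("dry_repeat_count holds %zu entries\n", dry_repeat_count.size()); // 0 when disabled
    return 0;
}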
@@ -2052,11 +2049,6 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
 struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
     llama_vocab dummy_vocab;
     auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
-
-    if (result == nullptr) {
-        return nullptr;
-    }
-
     auto * ctx = (llama_sampler_dry *) result->ctx;
 
     // Process the token-based sequence breakers
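Since llama_sampler_init_dry_impl no longer bails out for disabled parameters, the nullptr check in the testing constructor became dead code and is dropped; llama_sampler_init_dry_testing now unconditionally reuses the context produced by the real constructor before installing its token-based sequence breakers.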