llama : add llama_sampler_init for safe usage of llama_sampler_free (#11727)

The C API in llama.h claims users can implement `llama_sampler_i` to
create custom `llama_sampler`. The sampler chain takes ownership and
calls `llama_sampler_free` on them. However, `llama_sampler_free` is
hard-coded to use `delete`. This is undefined behavior if the object
wasn't also allocated via `new` from libllama's C++ runtime. Callers
in C and C-compatible languages do not use C++'s `new` operator. C++
callers may not be sharing the same heap as libllama.
This commit is contained in:
Christian Fillion 2025-02-07 04:33:27 -05:00 committed by GitHub
parent ec3bc8270b
commit 7ee953a64a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 70 additions and 62 deletions

View file

@ -254,10 +254,10 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
}; };
} }
return new llama_sampler{ return llama_sampler_init(
/* .iface = */ &llama_sampler_llg_i, /* .iface = */ &llama_sampler_llg_i,
/* .ctx = */ ctx, /* .ctx = */ ctx
}; );
} }
#else #else

View file

@ -1114,11 +1114,12 @@ extern "C" {
}; };
struct llama_sampler { struct llama_sampler {
struct llama_sampler_i * iface; const struct llama_sampler_i * iface;
llama_sampler_context_t ctx; llama_sampler_context_t ctx;
}; };
// mirror of llama_sampler_i: // mirror of llama_sampler_i:
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);

View file

@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) {
// llama_sampler API // llama_sampler API
struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
return new llama_sampler {
/* .iface = */ iface,
/* .ctx = */ ctx,
};
}
const char * llama_sampler_name(const struct llama_sampler * smpl) { const char * llama_sampler_name(const struct llama_sampler * smpl) {
if (!smpl->iface) { if (!smpl->iface) {
return "(null)"; return "(null)";
@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
} }
if (smpl->ctx == nullptr) { if (smpl->ctx == nullptr) {
return new llama_sampler { return llama_sampler_init(
/* .iface = */ smpl->iface, /* .iface = */ smpl->iface,
/* .ctx = */ nullptr, /* .ctx = */ nullptr
}; );
} }
GGML_ABORT("the sampler does not support cloning"); GGML_ABORT("the sampler does not support cloning");
@ -472,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
}; };
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_chain_i, /* .iface = */ &llama_sampler_chain_i,
/* .ctx = */ new llama_sampler_chain { /* .ctx = */ new llama_sampler_chain {
/* .params = */ params, /* .params = */ params,
/* .samplers = */ {}, /* .samplers = */ {},
/* .t_sample_us = */ 0, /* .t_sample_us = */ 0,
/* .n_sample = */ 0, /* .n_sample = */ 0,
}, }
}; );
} }
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
@ -546,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
}; };
struct llama_sampler * llama_sampler_init_greedy() { struct llama_sampler * llama_sampler_init_greedy() {
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_greedy_i, /* .iface = */ &llama_sampler_greedy_i,
/* .ctx = */ nullptr, /* .ctx = */ nullptr
}; );
} }
// dist // dist
@ -608,14 +615,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
auto seed_cur = get_rng_seed(seed); auto seed_cur = get_rng_seed(seed);
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_dist_i, /* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist { /* .ctx = */ new llama_sampler_dist {
/* .seed = */ seed, /* .seed = */ seed,
/* .seed_cur = */ seed_cur, /* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur), /* .rng = */ std::mt19937(seed_cur),
}, }
}; );
} }
// softmax // softmax
@ -638,10 +645,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
}; };
struct llama_sampler * llama_sampler_init_softmax() { struct llama_sampler * llama_sampler_init_softmax() {
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_softmax_i, /* .iface = */ &llama_sampler_softmax_i,
/* .ctx = */ nullptr, /* .ctx = */ nullptr
}; );
} }
// top-k // top-k
@ -678,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
}; };
struct llama_sampler * llama_sampler_init_top_k(int32_t k) { struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_top_k_i, /* .iface = */ &llama_sampler_top_k_i,
/* .ctx = */ new llama_sampler_top_k { /* .ctx = */ new llama_sampler_top_k {
/* .k = */ k, /* .k = */ k,
}, }
}; );
} }
// top-p // top-p
@ -744,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
}; };
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_top_p_i, /* .iface = */ &llama_sampler_top_p_i,
/* .ctx = */ new llama_sampler_top_p { /* .ctx = */ new llama_sampler_top_p {
/* .p = */ p, /* .p = */ p,
/* .min_keep = */ min_keep, /* .min_keep = */ min_keep,
}, }
}; );
} }
// min-p // min-p
@ -840,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
}; };
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_min_p_i, /* .iface = */ &llama_sampler_min_p_i,
/* .ctx = */ new llama_sampler_min_p { /* .ctx = */ new llama_sampler_min_p {
/* .p = */ p, /* .p = */ p,
/* .min_keep = */ min_keep, /* .min_keep = */ min_keep,
}, }
}; );
} }
// typical // typical
@ -939,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
}; };
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_typical_i, /* .iface = */ &llama_sampler_typical_i,
/* .ctx = */ new llama_sampler_typical { /* .ctx = */ new llama_sampler_typical {
/* .p = */ p, /* .p = */ p,
/* .min_keep = */ min_keep, /* .min_keep = */ min_keep,
}, }
}; );
} }
// temp // temp
@ -983,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
}; };
struct llama_sampler * llama_sampler_init_temp(float temp) { struct llama_sampler * llama_sampler_init_temp(float temp) {
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_temp_i, /* .iface = */ &llama_sampler_temp_i,
/* .ctx = */ new llama_sampler_temp { /* .ctx = */ new llama_sampler_temp {
/*.temp = */ temp, /*.temp = */ temp,
}, }
}; );
} }
// temp-ext // temp-ext
@ -1093,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
}; };
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_temp_ext_i, /* .iface = */ &llama_sampler_temp_ext_i,
/* .ctx = */ new llama_sampler_temp_ext { /* .ctx = */ new llama_sampler_temp_ext {
/* .temp = */ temp, /* .temp = */ temp,
/* .delta = */ delta, /* .delta = */ delta,
/* .exponent = */ exponent, /* .exponent = */ exponent,
}, }
}; );
} }
// xtc // xtc
@ -1185,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
auto seed_cur = get_rng_seed(seed); auto seed_cur = get_rng_seed(seed);
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_xtc_i, /* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc { /* .ctx = */ new llama_sampler_xtc {
/* .probability = */ p, /* .probability = */ p,
@ -1194,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
/* .seed = */ seed, /* .seed = */ seed,
/* .seed_cur = */ seed_cur, /* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur), /* .rng = */ std::mt19937(seed_cur),
}, }
}; );
} }
// mirostat // mirostat
@ -1292,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
auto seed_cur = get_rng_seed(seed); auto seed_cur = get_rng_seed(seed);
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_mirostat_i, /* .iface = */ &llama_sampler_mirostat_i,
/* .ctx = */ new llama_sampler_mirostat { /* .ctx = */ new llama_sampler_mirostat {
/* .n_vocab = */ n_vocab, /* .n_vocab = */ n_vocab,
@ -1303,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
/* .m = */ m, /* .m = */ m,
/* .mu = */ 2.0f*tau, /* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur), /* .rng = */ std::mt19937(seed_cur),
}, }
}; );
} }
// mirostat v2 // mirostat v2
@ -1391,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
auto seed_cur = get_rng_seed(seed); auto seed_cur = get_rng_seed(seed);
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_mirostat_v2_i, /* .iface = */ &llama_sampler_mirostat_v2_i,
/* .ctx = */ new llama_sampler_mirostat_v2 { /* .ctx = */ new llama_sampler_mirostat_v2 {
/* .seed = */ seed, /* .seed = */ seed,
@ -1400,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
/* .eta = */ eta, /* .eta = */ eta,
/* .mu = */ 2.0f*tau, /* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur), /* .rng = */ std::mt19937(seed_cur),
}, }
}; );
} }
// grammar // grammar
@ -1528,10 +1535,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
}; };
} }
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_grammar_i, /* .iface = */ &llama_sampler_grammar_i,
/* .ctx = */ ctx, /* .ctx = */ ctx
}; );
} }
struct llama_sampler * llama_sampler_init_grammar( struct llama_sampler * llama_sampler_init_grammar(
@ -1678,7 +1685,7 @@ struct llama_sampler * llama_sampler_init_penalties(
float penalty_present) { float penalty_present) {
penalty_last_n = std::max(penalty_last_n, 0); penalty_last_n = std::max(penalty_last_n, 0);
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_penalties_i, /* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties { /* .ctx = */ new llama_sampler_penalties {
/* .penalty_last_n = */ penalty_last_n, /* .penalty_last_n = */ penalty_last_n,
@ -1687,8 +1694,8 @@ struct llama_sampler * llama_sampler_init_penalties(
/* .penalty_present = */ penalty_present, /* .penalty_present = */ penalty_present,
/* .prev = */ ring_buffer<llama_token>(penalty_last_n), /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
/* .token_count = */ {}, /* .token_count = */ {},
}, }
}; );
} }
// DRY // DRY
@ -2041,7 +2048,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
} }
} }
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_dry_i, /* .iface = */ &llama_sampler_dry_i,
/* .ctx = */ new llama_sampler_dry { /* .ctx = */ new llama_sampler_dry {
/* .total_context_size = */ context_size, /* .total_context_size = */ context_size,
@ -2053,8 +2060,8 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{}, /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
/* .dry_max_token_repeat = */ {}, /* .dry_max_token_repeat = */ {},
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0), /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
}, }
}; );
} }
// wrapper for test-sampling.cpp // wrapper for test-sampling.cpp
@ -2155,14 +2162,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab, int32_t n_vocab,
int32_t n_logit_bias, int32_t n_logit_bias,
const llama_logit_bias * logit_bias) { const llama_logit_bias * logit_bias) {
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_logit_bias_i, /* .iface = */ &llama_sampler_logit_bias_i,
/* .ctx = */ new llama_sampler_logit_bias { /* .ctx = */ new llama_sampler_logit_bias {
/* .n_vocab = */ n_vocab, /* .n_vocab = */ n_vocab,
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias), /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
/* .to_search = */ {}, /* .to_search = */ {},
}, }
}; );
} }
// infill // infill
@ -2377,14 +2384,14 @@ static struct llama_sampler_i llama_sampler_infill_i = {
}; };
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
return new llama_sampler { return llama_sampler_init(
/* .iface = */ &llama_sampler_infill_i, /* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill { /* .ctx = */ new llama_sampler_infill {
/* .vocab = */ vocab, /* .vocab = */ vocab,
/* .buf0 = */ std::vector<char>(512), /* .buf0 = */ std::vector<char>(512),
/* .buf1 = */ std::vector<char>(512), /* .buf1 = */ std::vector<char>(512),
}, }
}; );
} }
// utils // utils