sampling : simplify new llama_sampler calls

This commit is contained in:
Georgi Gerganov 2024-09-04 20:21:02 +03:00
parent 784a644040
commit e7a11cac0e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -439,12 +439,10 @@ static struct llama_constraint_i llama_constraint_softmax_i = {
}; };
struct llama_constraint * llama_constraint_init_softmax_impl() { struct llama_constraint * llama_constraint_init_softmax_impl() {
struct llama_constraint * result = new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_softmax_i, /* .iface = */ &llama_constraint_softmax_i,
/* .ctx = */ nullptr, /* .ctx = */ nullptr,
}; };
return result;
} }
// top-k // top-k
@ -472,15 +470,13 @@ static struct llama_constraint_i llama_constraint_top_k_i = {
}; };
struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) { struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) {
struct llama_constraint * result = new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_top_k_i, /* .iface = */ &llama_constraint_top_k_i,
/* .ctx = */ new llama_constraint_context_top_k { /* .ctx = */ new llama_constraint_context_top_k {
/*.k =*/ k, /*.k =*/ k,
/*.min_keep =*/ min_keep, /*.min_keep =*/ min_keep,
}, },
}; };
return result;
} }
// top-p // top-p
@ -508,15 +504,13 @@ static struct llama_constraint_i llama_constraint_top_p_i = {
}; };
struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) { struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) {
struct llama_constraint * result = new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_top_p_i, /* .iface = */ &llama_constraint_top_p_i,
/* .ctx = */ new llama_constraint_context_top_p { /* .ctx = */ new llama_constraint_context_top_p {
/*.p =*/ p, /*.p =*/ p,
/*.min_keep =*/ min_keep, /*.min_keep =*/ min_keep,
}, },
}; };
return result;
} }
// min-p // min-p
@ -544,15 +538,13 @@ static struct llama_constraint_i llama_constraint_min_p_i = {
}; };
struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) { struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) {
struct llama_constraint * result = new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_min_p_i, /* .iface = */ &llama_constraint_min_p_i,
/* .ctx = */ new llama_constraint_context_min_p { /* .ctx = */ new llama_constraint_context_min_p {
/*.p =*/ p, /*.p =*/ p,
/*.min_keep =*/ min_keep, /*.min_keep =*/ min_keep,
}, },
}; };
return result;
} }
// tail-free // tail-free
@ -580,15 +572,13 @@ static struct llama_constraint_i llama_constraint_tail_free_i = {
}; };
struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) { struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) {
struct llama_constraint * result = new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_tail_free_i, /* .iface = */ &llama_constraint_tail_free_i,
/* .ctx = */ new llama_constraint_context_tail_free { /* .ctx = */ new llama_constraint_context_tail_free {
/*.z =*/ z, /*.z =*/ z,
/*.min_keep =*/ min_keep, /*.min_keep =*/ min_keep,
}, },
}; };
return result;
} }
// typical // typical
@ -616,15 +606,13 @@ static struct llama_constraint_i llama_constraint_typical_i = {
}; };
struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) { struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) {
struct llama_constraint * result = new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_typical_i, /* .iface = */ &llama_constraint_typical_i,
/* .ctx = */ new llama_constraint_context_typical { /* .ctx = */ new llama_constraint_context_typical {
/*.p =*/ p, /*.p =*/ p,
/*.min_keep =*/ min_keep, /*.min_keep =*/ min_keep,
}, },
}; };
return result;
} }
// temp // temp
@ -651,14 +639,12 @@ static struct llama_constraint_i llama_constraint_temp_i = {
}; };
struct llama_constraint * llama_constraint_init_temp_impl(float temp) { struct llama_constraint * llama_constraint_init_temp_impl(float temp) {
struct llama_constraint * result = new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_temp_i, /* .iface = */ &llama_constraint_temp_i,
/* .ctx = */ new llama_constraint_context_temp { /* .ctx = */ new llama_constraint_context_temp {
/*.temp =*/ temp, /*.temp =*/ temp,
}, },
}; };
return result;
} }
// temp-ext // temp-ext
@ -694,7 +680,7 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = {
}; };
struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) { struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) {
struct llama_constraint * result = new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_temp_ext_i, /* .iface = */ &llama_constraint_temp_ext_i,
/* .ctx = */ new llama_constraint_context_temp_ext { /* .ctx = */ new llama_constraint_context_temp_ext {
/*.temp =*/ temp, /*.temp =*/ temp,
@ -702,8 +688,6 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float
/*.exponent =*/ exponent, /*.exponent =*/ exponent,
}, },
}; };
return result;
} }
// mirostat // mirostat
@ -782,12 +766,8 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
}, },
}; };
struct llama_constraint * llama_constraint_init_mirostat_impl( struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama_vocab & vocab, float tau, float eta, int32_t m) {
const struct llama_vocab & vocab, return new llama_constraint {
float tau,
float eta,
int32_t m) {
struct llama_constraint * result = new llama_constraint {
/* .iface = */ &llama_constraint_mirostat_i, /* .iface = */ &llama_constraint_mirostat_i,
/* .ctx = */ new llama_constraint_context_mirostat { /* .ctx = */ new llama_constraint_context_mirostat {
/*.vocab =*/ &vocab, /*.vocab =*/ &vocab,
@ -798,8 +778,6 @@ struct llama_constraint * llama_constraint_init_mirostat_impl(
/*.cur =*/ {}, /*.cur =*/ {},
}, },
}; };
return result;
} }
// mirostat v2 // mirostat v2
@ -863,7 +841,7 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = {
}; };
struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, float eta) { struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, float eta) {
struct llama_constraint * result = new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_mirostat_v2_i, /* .iface = */ &llama_constraint_mirostat_v2_i,
/* .ctx = */ new llama_constraint_context_mirostat_v2 { /* .ctx = */ new llama_constraint_context_mirostat_v2 {
/*.tau =*/ tau, /*.tau =*/ tau,
@ -872,8 +850,6 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa
/*.cur =*/ {}, /*.cur =*/ {},
}, },
}; };
return result;
} }
// grammar // grammar
@ -953,12 +929,10 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_
}; };
} }
struct llama_constraint * result = new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_grammar_i, /* .iface = */ &llama_constraint_grammar_i,
/* .ctx = */ ctx, /* .ctx = */ ctx,
}; };
return result;
} }
// penalties // penalties
@ -1039,10 +1013,10 @@ static struct llama_constraint_i llama_constraint_penalties_i = {
}; };
struct llama_constraint * llama_constraint_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) { struct llama_constraint * llama_constraint_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) {
GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL); GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL);
GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL); GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL);
struct llama_constraint * result = new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_penalties_i, /* .iface = */ &llama_constraint_penalties_i,
/* .ctx = */ new llama_constraint_context_penalties { /* .ctx = */ new llama_constraint_context_penalties {
/*.vocab =*/ &vocab, /*.vocab =*/ &vocab,
@ -1055,8 +1029,6 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam
/*.prev =*/ ring_buffer<llama_token>(penalty_last_n), /*.prev =*/ ring_buffer<llama_token>(penalty_last_n),
}, },
}; };
return result;
} }
// logit-bias // logit-bias
@ -1093,15 +1065,13 @@ struct llama_constraint * llama_constraint_init_logit_bias_impl(
const struct llama_vocab & vocab, const struct llama_vocab & vocab,
int32_t n_logit_bias, int32_t n_logit_bias,
const llama_logit_bias * logit_bias) { const llama_logit_bias * logit_bias) {
struct llama_constraint * result = new llama_constraint { return new llama_constraint {
/* .iface = */ &llama_constraint_logit_bias_i, /* .iface = */ &llama_constraint_logit_bias_i,
/* .ctx = */ new llama_constraint_context_logit_bias { /* .ctx = */ new llama_constraint_context_logit_bias {
/*.vocab =*/ &vocab, /*.vocab =*/ &vocab,
/*.logit_bias=*/ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias), /*.logit_bias=*/ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
}, },
}; };
return result;
} }
//////////////////////////////////////// ////////////////////////////////////////
@ -1144,7 +1114,7 @@ void llama_constraint_reset_impl(struct llama_constraint & cnstr) {
// //
struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) { struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) {
auto * result = new llama_sampler { return new llama_sampler {
/* .params = */ params, /* .params = */ params,
/* .vocab = */ &vocab, /* .vocab = */ &vocab,
@ -1157,8 +1127,6 @@ struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab,
/* .t_sample_us = */ 0, /* .t_sample_us = */ 0,
/* .n_sample = */ 0, /* .n_sample = */ 0,
}; };
return result;
} }
void llama_sampler_free_impl(struct llama_sampler * smpl) { void llama_sampler_free_impl(struct llama_sampler * smpl) {