sampling : simplify new llama_sampler calls

This commit is contained in:
Georgi Gerganov 2024-09-04 20:21:02 +03:00
parent 784a644040
commit e7a11cac0e
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -439,12 +439,10 @@ static struct llama_constraint_i llama_constraint_softmax_i = {
};
struct llama_constraint * llama_constraint_init_softmax_impl() {
struct llama_constraint * result = new llama_constraint {
return new llama_constraint {
/* .iface = */ &llama_constraint_softmax_i,
/* .ctx = */ nullptr,
};
return result;
}
// top-k
@ -472,15 +470,13 @@ static struct llama_constraint_i llama_constraint_top_k_i = {
};
struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) {
struct llama_constraint * result = new llama_constraint {
return new llama_constraint {
/* .iface = */ &llama_constraint_top_k_i,
/* .ctx = */ new llama_constraint_context_top_k {
/*.k =*/ k,
/*.min_keep =*/ min_keep,
},
};
return result;
}
// top-p
@ -508,15 +504,13 @@ static struct llama_constraint_i llama_constraint_top_p_i = {
};
struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) {
struct llama_constraint * result = new llama_constraint {
return new llama_constraint {
/* .iface = */ &llama_constraint_top_p_i,
/* .ctx = */ new llama_constraint_context_top_p {
/*.p =*/ p,
/*.min_keep =*/ min_keep,
},
};
return result;
}
// min-p
@ -544,15 +538,13 @@ static struct llama_constraint_i llama_constraint_min_p_i = {
};
struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) {
struct llama_constraint * result = new llama_constraint {
return new llama_constraint {
/* .iface = */ &llama_constraint_min_p_i,
/* .ctx = */ new llama_constraint_context_min_p {
/*.p =*/ p,
/*.min_keep =*/ min_keep,
},
};
return result;
}
// tail-free
@ -580,15 +572,13 @@ static struct llama_constraint_i llama_constraint_tail_free_i = {
};
struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) {
struct llama_constraint * result = new llama_constraint {
return new llama_constraint {
/* .iface = */ &llama_constraint_tail_free_i,
/* .ctx = */ new llama_constraint_context_tail_free {
/*.z =*/ z,
/*.min_keep =*/ min_keep,
},
};
return result;
}
// typical
@ -616,15 +606,13 @@ static struct llama_constraint_i llama_constraint_typical_i = {
};
struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) {
struct llama_constraint * result = new llama_constraint {
return new llama_constraint {
/* .iface = */ &llama_constraint_typical_i,
/* .ctx = */ new llama_constraint_context_typical {
/*.p =*/ p,
/*.min_keep =*/ min_keep,
},
};
return result;
}
// temp
@ -651,14 +639,12 @@ static struct llama_constraint_i llama_constraint_temp_i = {
};
struct llama_constraint * llama_constraint_init_temp_impl(float temp) {
struct llama_constraint * result = new llama_constraint {
return new llama_constraint {
/* .iface = */ &llama_constraint_temp_i,
/* .ctx = */ new llama_constraint_context_temp {
/*.temp =*/ temp,
},
};
return result;
}
// temp-ext
@ -694,7 +680,7 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = {
};
struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) {
struct llama_constraint * result = new llama_constraint {
return new llama_constraint {
/* .iface = */ &llama_constraint_temp_ext_i,
/* .ctx = */ new llama_constraint_context_temp_ext {
/*.temp =*/ temp,
@ -702,8 +688,6 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float
/*.exponent =*/ exponent,
},
};
return result;
}
// mirostat
@ -782,12 +766,8 @@ static struct llama_constraint_i llama_constraint_mirostat_i = {
},
};
struct llama_constraint * llama_constraint_init_mirostat_impl(
const struct llama_vocab & vocab,
float tau,
float eta,
int32_t m) {
struct llama_constraint * result = new llama_constraint {
struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama_vocab & vocab, float tau, float eta, int32_t m) {
return new llama_constraint {
/* .iface = */ &llama_constraint_mirostat_i,
/* .ctx = */ new llama_constraint_context_mirostat {
/*.vocab =*/ &vocab,
@ -798,8 +778,6 @@ struct llama_constraint * llama_constraint_init_mirostat_impl(
/*.cur =*/ {},
},
};
return result;
}
// mirostat v2
@ -863,7 +841,7 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = {
};
struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, float eta) {
struct llama_constraint * result = new llama_constraint {
return new llama_constraint {
/* .iface = */ &llama_constraint_mirostat_v2_i,
/* .ctx = */ new llama_constraint_context_mirostat_v2 {
/*.tau =*/ tau,
@ -872,8 +850,6 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa
/*.cur =*/ {},
},
};
return result;
}
// grammar
@ -953,12 +929,10 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_
};
}
struct llama_constraint * result = new llama_constraint {
return new llama_constraint {
/* .iface = */ &llama_constraint_grammar_i,
/* .ctx = */ ctx,
};
return result;
}
// penalties
@ -1042,7 +1016,7 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam
GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL);
GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL);
struct llama_constraint * result = new llama_constraint {
return new llama_constraint {
/* .iface = */ &llama_constraint_penalties_i,
/* .ctx = */ new llama_constraint_context_penalties {
/*.vocab =*/ &vocab,
@ -1055,8 +1029,6 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam
/*.prev =*/ ring_buffer<llama_token>(penalty_last_n),
},
};
return result;
}
// logit-bias
@ -1093,15 +1065,13 @@ struct llama_constraint * llama_constraint_init_logit_bias_impl(
const struct llama_vocab & vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias) {
struct llama_constraint * result = new llama_constraint {
return new llama_constraint {
/* .iface = */ &llama_constraint_logit_bias_i,
/* .ctx = */ new llama_constraint_context_logit_bias {
/*.vocab =*/ &vocab,
/*.logit_bias=*/ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
},
};
return result;
}
////////////////////////////////////////
@ -1144,7 +1114,7 @@ void llama_constraint_reset_impl(struct llama_constraint & cnstr) {
//
struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) {
auto * result = new llama_sampler {
return new llama_sampler {
/* .params = */ params,
/* .vocab = */ &vocab,
@ -1157,8 +1127,6 @@ struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab,
/* .t_sample_us = */ 0,
/* .n_sample = */ 0,
};
return result;
}
void llama_sampler_free_impl(struct llama_sampler * smpl) {