sampling : add name API + option to disable timings
This commit is contained in:
parent
ebeb65194b
commit
595711417a
5 changed files with 33 additions and 14 deletions
|
@ -31,7 +31,7 @@ std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
|
|||
|
||||
for (int i = 0; i < llama_sampler_n_constraints(gsmpl->smpl); i++) {
|
||||
const auto * cnstr = llama_sampler_constraint_get(gsmpl->smpl, i);
|
||||
result += " -> " + std::string(cnstr->iface->name(cnstr)) + " ";
|
||||
result += std::string(" -> ") + llama_constraint_name(cnstr) + " ";
|
||||
}
|
||||
|
||||
return result;
|
||||
|
|
|
@ -381,6 +381,8 @@ extern "C" {
|
|||
|
||||
// TODO: will be used by the llama_decode_with_sampler() API in the future
|
||||
enum llama_sampler_type type;
|
||||
|
||||
bool no_timing; // whether to measure performance timings
|
||||
} llama_sampler_params;
|
||||
|
||||
// performance timing information
|
||||
|
@ -1097,6 +1099,7 @@ extern "C" {
|
|||
// important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add)
|
||||
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);
|
||||
|
||||
LLAMA_API const char * llama_constraint_name (const struct llama_constraint * cnstr);
|
||||
LLAMA_API void llama_constraint_accept( struct llama_constraint * cnstr, llama_token token);
|
||||
LLAMA_API void llama_constraint_apply ( struct llama_constraint * cnstr, llama_token_data_array * cur_p);
|
||||
LLAMA_API void llama_constraint_reset ( struct llama_constraint * cnstr);
|
||||
|
|
|
@ -1190,6 +1190,14 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) {
|
|||
// TODO: should we reset the timings?
|
||||
}
|
||||
|
||||
const char * llama_constraint_name_impl(const struct llama_constraint & cnstr) {
|
||||
if (!cnstr.iface) {
|
||||
return "(null)";
|
||||
}
|
||||
|
||||
return cnstr.iface->name(&cnstr);
|
||||
}
|
||||
|
||||
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
|
||||
smpl.prev.push_back(token);
|
||||
|
||||
|
|
|
@ -62,6 +62,7 @@ struct llama_constraint * llama_constraint_clone_impl(const struct llama_constra
|
|||
|
||||
void llama_constraint_free_impl(struct llama_constraint * cnstr);
|
||||
|
||||
const char * llama_constraint_name_impl (const struct llama_constraint & cnstr);
|
||||
void llama_constraint_accept_impl( struct llama_constraint & cnstr, llama_token token);
|
||||
void llama_constraint_apply_impl ( struct llama_constraint & cnstr, struct llama_token_data_array * cur_p);
|
||||
void llama_constraint_reset_impl ( struct llama_constraint & cnstr);
|
||||
|
|
|
@ -148,11 +148,13 @@ static void zeros(std::ofstream & file, size_t n) {
|
|||
}
|
||||
|
||||
struct time_meas {
|
||||
time_meas(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {}
|
||||
time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
|
||||
|
||||
~time_meas() {
|
||||
if (t_start_us >= 0) {
|
||||
t_acc += ggml_time_us() - t_start_us;
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t t_start_us;
|
||||
|
||||
|
@ -17940,6 +17942,7 @@ struct llama_sampler_params llama_sampler_default_params() {
|
|||
/*.seed =*/ LLAMA_DEFAULT_SEED,
|
||||
/*.n_prev =*/ 256,
|
||||
/*.type =*/ LLAMA_SAMPLER_TYPE_DIST,
|
||||
/*.no_timing =*/ false, // TODO: change to true and set explicitly in examples
|
||||
};
|
||||
|
||||
return result;
|
||||
|
@ -20681,6 +20684,10 @@ void llama_constraint_free(struct llama_constraint * cnstr) {
|
|||
llama_constraint_free_impl(cnstr);
|
||||
}
|
||||
|
||||
const char * llama_constraint_name(const struct llama_constraint * cnstr) {
|
||||
return llama_constraint_name_impl(*cnstr);
|
||||
}
|
||||
|
||||
void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token) {
|
||||
llama_constraint_accept_impl(*cnstr, token);
|
||||
}
|
||||
|
@ -20718,7 +20725,7 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
|||
}
|
||||
|
||||
void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
time_meas tm(smpl->t_sample_us);
|
||||
time_meas tm(smpl->t_sample_us, smpl->params.no_timing);
|
||||
|
||||
if (cur_p == nullptr) {
|
||||
cur_p = &smpl->cur_p;
|
||||
|
@ -20756,7 +20763,7 @@ struct llama_constraint * llama_sampler_constraint_get(const struct llama_sample
|
|||
}
|
||||
|
||||
llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
time_meas tm(smpl->t_sample_us);
|
||||
time_meas tm(smpl->t_sample_us, smpl->params.no_timing);
|
||||
|
||||
if (cur_p == nullptr) {
|
||||
cur_p = &smpl->cur_p;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue