sampling : add name API + option to disable timings

This commit is contained in:
Georgi Gerganov 2024-09-05 10:33:04 +03:00
parent ebeb65194b
commit 595711417a
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 33 additions and 14 deletions

View file

@ -31,7 +31,7 @@ std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
for (int i = 0; i < llama_sampler_n_constraints(gsmpl->smpl); i++) {
const auto * cnstr = llama_sampler_constraint_get(gsmpl->smpl, i);
result += " -> " + std::string(cnstr->iface->name(cnstr)) + " ";
result += std::string(" -> ") + llama_constraint_name(cnstr) + " ";
}
return result;

View file

@ -381,6 +381,8 @@ extern "C" {
// TODO: will be used by the llama_decode_with_sampler() API in the future
enum llama_sampler_type type;
bool no_timing; // whether to measure performance timings
} llama_sampler_params;
// performance timing information
@ -1097,6 +1099,7 @@ extern "C" {
// important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add)
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);
LLAMA_API const char * llama_constraint_name (const struct llama_constraint * cnstr);
LLAMA_API void llama_constraint_accept( struct llama_constraint * cnstr, llama_token token);
LLAMA_API void llama_constraint_apply ( struct llama_constraint * cnstr, llama_token_data_array * cur_p);
LLAMA_API void llama_constraint_reset ( struct llama_constraint * cnstr);

View file

@ -1190,6 +1190,14 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) {
// TODO: should we reset the timings?
}
const char * llama_constraint_name_impl(const struct llama_constraint & cnstr) {
if (!cnstr.iface) {
return "(null)";
}
return cnstr.iface->name(&cnstr);
}
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
smpl.prev.push_back(token);

View file

@ -62,6 +62,7 @@ struct llama_constraint * llama_constraint_clone_impl(const struct llama_constra
void llama_constraint_free_impl(struct llama_constraint * cnstr);
const char * llama_constraint_name_impl (const struct llama_constraint & cnstr);
void llama_constraint_accept_impl( struct llama_constraint & cnstr, llama_token token);
void llama_constraint_apply_impl ( struct llama_constraint & cnstr, struct llama_token_data_array * cur_p);
void llama_constraint_reset_impl ( struct llama_constraint & cnstr);

View file

@ -148,11 +148,13 @@ static void zeros(std::ofstream & file, size_t n) {
}
struct time_meas {
time_meas(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {}
time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
~time_meas() {
if (t_start_us >= 0) {
t_acc += ggml_time_us() - t_start_us;
}
}
const int64_t t_start_us;
@ -17940,6 +17942,7 @@ struct llama_sampler_params llama_sampler_default_params() {
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_prev =*/ 256,
/*.type =*/ LLAMA_SAMPLER_TYPE_DIST,
/*.no_timing =*/ false, // TODO: change to true and set explicitly in examples
};
return result;
@ -20681,6 +20684,10 @@ void llama_constraint_free(struct llama_constraint * cnstr) {
llama_constraint_free_impl(cnstr);
}
const char * llama_constraint_name(const struct llama_constraint * cnstr) {
return llama_constraint_name_impl(*cnstr);
}
void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token) {
llama_constraint_accept_impl(*cnstr, token);
}
@ -20718,7 +20725,7 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
}
void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
time_meas tm(smpl->t_sample_us);
time_meas tm(smpl->t_sample_us, smpl->params.no_timing);
if (cur_p == nullptr) {
cur_p = &smpl->cur_p;
@ -20756,7 +20763,7 @@ struct llama_constraint * llama_sampler_constraint_get(const struct llama_sample
}
llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
time_meas tm(smpl->t_sample_us);
time_meas tm(smpl->t_sample_us, smpl->params.no_timing);
if (cur_p == nullptr) {
cur_p = &smpl->cur_p;