sampler : API to iterate constraints
ggml-ci
parent 0e1378c844
commit 784a644040
10 changed files with 69 additions and 48 deletions
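For orientation before the per-file hunks: this change renames llama_sampler_add_constraint to llama_sampler_constraint_add and adds llama_sampler_n_constraints / llama_sampler_constraint_get so callers can walk a sampler's constraint chain; on the common layer, print_all() becomes print() and the constraint listing moves into gpt_sampler_print(). The sketch below is not part of the diff — it is a minimal usage illustration assuming a model and gpt_sampler_params (sparams) are already set up as in the examples touched by this commit.

    // hypothetical usage sketch based on the calls visible in this diff
    struct gpt_sampler * smpl = gpt_sampler_init(model, sparams);

    LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());         // renamed from print_all()
    LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str()); // iterates the constraint chain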
@@ -12,7 +12,7 @@ struct gpt_sampler {
     struct llama_sampler * smpl;
 };

-std::string gpt_sampler_params::print_all() const {
+std::string gpt_sampler_params::print() const {
     char result[1024];

     snprintf(result, sizeof(result),

@@ -26,17 +26,12 @@ std::string gpt_sampler_params::print_all() const {
     return std::string(result);
 }

-std::string gpt_sampler_params::print_constraints() const {
-    std::string result = "CFG -> Penalties ";
-    if (mirostat == 0) {
-        for (const auto & cnstr : constraints) {
-            const auto name = gpt_constraint_type_to_str(cnstr);
-            if (!name.empty()) {
-                result += "-> " + name + " ";
-            }
-        }
-    } else {
-        result += "-> mirostat ";
+std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
+    std::string result = "\tlogits";
+
+    for (int i = 0; i < llama_sampler_n_constraints(gsmpl->smpl); i++) {
+        const auto * cnstr = llama_sampler_constraint_get(gsmpl->smpl, i);
+        result += " -> " + std::string(cnstr->iface->name(cnstr)) + " ";
+    }
+
+    return result;

@@ -70,33 +65,33 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
         for (const auto & cnstr : params.constraints) {
             switch (cnstr) {
                 case GPT_CONSTRAINT_TYPE_TOP_K:
-                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k    (params.top_k, params.min_keep));
+                    llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k    (params.top_k, params.min_keep));
                     break;
                 case GPT_CONSTRAINT_TYPE_TOP_P:
-                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p    (params.top_p, params.min_keep));
+                    llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p    (params.top_p, params.min_keep));
                     break;
                 case GPT_CONSTRAINT_TYPE_MIN_P:
-                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p    (params.min_p, params.min_keep));
+                    llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p    (params.min_p, params.min_keep));
                     break;
                 case GPT_CONSTRAINT_TYPE_TFS_Z:
-                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
+                    llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
                     break;
                 case GPT_CONSTRAINT_TYPE_TYPICAL_P:
-                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical  (params.typ_p, params.min_keep));
+                    llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical  (params.typ_p, params.min_keep));
                     break;
                 case GPT_CONSTRAINT_TYPE_TEMPERATURE:
-                    llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                     break;
                 default:
                     GGML_ASSERT(false && "unknown constraint type");
             }
         }
     } else if (params.mirostat == 1) {
-        llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
-        llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
+        llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
+        llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
     } else if (params.mirostat == 2) {
-        llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
-        llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
+        llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
+        llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
     } else {
         GGML_ASSERT(false && "unknown mirostat version");
     }

@@ -54,10 +54,7 @@ struct gpt_sampler_params {
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply

     // print the parameters into a string
-    std::string print_all() const;
-
-    // print the constraints into a string
-    std::string print_constraints() const;
+    std::string print() const;
 };

 // gpt_sampler extends llama_sampler with additional functionality:

@@ -100,6 +97,9 @@ llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_da

 // helpers

+// print the constraints into a string
+std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);
+
 // get a string representation of the last accepted tokens
 std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n);

@@ -61,9 +61,9 @@ defer {
     llama_sampler_free(smpl)
 }

-llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
-llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9, 1));
-llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4));
+llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1));
+llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1));
+llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4));

 let n_ctx = llama_n_ctx(context)

@@ -70,9 +70,9 @@ int main(int argc, char ** argv) {

     llama_sampler * smpl = llama_sampler_init(model, sparams);

-    llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
-    llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
-    llama_sampler_add_constraint(smpl, llama_constraint_init_temp (params.sparams.temp));
+    llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
+    llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
+    llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp));

     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);

@@ -301,7 +301,7 @@ int main(int argc, char ** argv) {
             LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
-    LOG_TEE("sampling: \n%s\n", sparams.print_all().c_str());
+    LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");

@@ -457,8 +457,15 @@ int main(int argc, char ** argv) {
             }
         }
     }
-    LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str());
-    LOG_TEE("sampling constr: \n%s\n", sparams.print_constraints().c_str());
+
+    smpl = gpt_sampler_init(model, sparams);
+    if (!smpl) {
+        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
+        exit(1);
+    }
+
+    LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
+    LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
+
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);

     // group-attention state

@@ -525,12 +532,6 @@ int main(int argc, char ** argv) {
         antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
     }

-    smpl = gpt_sampler_init(model, sparams);
-    if (!smpl) {
-        fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
-        exit(1);
-    }
-
     if (llama_model_has_encoder(model)) {
         int enc_input_size = embd_inp.size();
         llama_token * enc_input_buf = embd_inp.data();

@@ -1012,7 +1012,7 @@ extern "C" {
     // The llama_sampler object contains the entire sampling information:
     //
     // - RNG state (seed and generator)
-    // - Custom set of constraints (see llama_sampler_add_constraint)
+    // - Custom set of constraints (see llama_sampler_constraint_add)
     // - Sampling method (greedy, dist)
     // - Previous tokens
     //

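Roughly, the pieces listed in that comment interact as in the sketch below. This is an assumption pieced together from declarations and calls visible elsewhere in this diff (llama_sampler_init, llama_sampler_constraint_add, llama_sampler_get_candidates, llama_sampler_apply, llama_sampler_accept, llama_sampler_free), not code taken from the commit; the token id comes from whatever sampling method (greedy/dist) the caller uses.

    // build the sampler once and attach its constraint chain
    struct llama_sampler * smpl = llama_sampler_init(model, sparams);
    llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1));
    llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.8f));

    // per token (sketch): constraints reshape the candidate list, the chosen token is recorded
    llama_token_data_array * cur_p = llama_sampler_get_candidates(smpl);
    llama_sampler_apply (smpl, cur_p);
    llama_sampler_accept(smpl, id);   // id: token chosen by the sampling method (greedy/dist)

    llama_sampler_free(smpl);         // also frees the constraints it owns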
@@ -1083,7 +1083,7 @@ extern "C" {

     LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr);

-    // important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_add_constraint)
+    // important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add)
     LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);

     LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token);

@@ -1102,7 +1102,10 @@ extern "C" {
     LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl);

     // important: takes ownership of the constraint object and will free it in llama_sampler_free
-    LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr);
+    LLAMA_API void llama_sampler_constraint_add(      struct llama_sampler * smpl, struct llama_constraint * cnstr);
+    LLAMA_API int  llama_sampler_n_constraints (const struct llama_sampler * smpl);
+    LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i);
+
     LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
     LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p);

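The two new getters are what the common-layer gpt_sampler_print above is built on. Below is a minimal sketch of the same iteration pattern, assuming an already populated smpl; the name lookup through cnstr->iface->name mirrors the common code in this diff and is not a new API.

    // note: the sampler owns added constraints (they are freed in llama_sampler_free),
    // so llama_constraint_free must not be called on them afterwards
    for (int i = 0; i < llama_sampler_n_constraints(smpl); i++) {
        const struct llama_constraint * cnstr = llama_sampler_constraint_get(smpl, i);
        printf(" -> %s", cnstr->iface->name(cnstr));
    }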
@@ -1215,10 +1215,22 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) {
     // TODO: should we reset the timings?
 }

-void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) {
+void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) {
     smpl.constraints.push_back(cnstr);
 }

+int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl) {
+    return smpl.constraints.size();
+}
+
+struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith) {
+    if (ith < 0 || ith >= (int) smpl.constraints.size()) {
+        return nullptr;
+    }
+
+    return smpl.constraints[ith];
+}
+
 void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
     smpl.prev.push_back(token);

@@ -109,7 +109,9 @@ void llama_sampler_free_impl ( struct llama_sampler * smp
 struct llama_sampler * llama_sampler_cp_impl   (const struct llama_sampler & smpl);
 void                   llama_sampler_reset_impl(      struct llama_sampler & smpl);

-void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr);
+void                      llama_sampler_constraint_add_impl(      struct llama_sampler & smpl, struct llama_constraint * cnstr);
+int                       llama_sampler_n_constraints_impl (const struct llama_sampler & smpl);
+struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith);

 void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token);
 void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * cur_p);

@@ -20729,8 +20729,16 @@ llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smp
     return &smpl->cur_p;
 }

-void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr) {
-    llama_sampler_add_constraint_impl(*smpl, cnstr);
+void llama_sampler_constraint_add(struct llama_sampler * smpl, struct llama_constraint * cnstr) {
+    llama_sampler_constraint_add_impl(*smpl, cnstr);
 }

+int llama_sampler_n_constraints (const struct llama_sampler * smpl) {
+    return llama_sampler_n_constraints_impl(*smpl);
+}
+
+struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i) {
+    return llama_sampler_constraint_get_impl(*smpl, i);
+}
+
 void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {