sampler : API to iterate constraints
ggml-ci
This commit is contained in:
parent
0e1378c844
commit
784a644040
10 changed files with 69 additions and 48 deletions
|
@ -12,7 +12,7 @@ struct gpt_sampler {
|
||||||
struct llama_sampler * smpl;
|
struct llama_sampler * smpl;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string gpt_sampler_params::print_all() const {
|
std::string gpt_sampler_params::print() const {
|
||||||
char result[1024];
|
char result[1024];
|
||||||
|
|
||||||
snprintf(result, sizeof(result),
|
snprintf(result, sizeof(result),
|
||||||
|
@ -26,17 +26,12 @@ std::string gpt_sampler_params::print_all() const {
|
||||||
return std::string(result);
|
return std::string(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string gpt_sampler_params::print_constraints() const {
|
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
|
||||||
std::string result = "CFG -> Penalties ";
|
std::string result = "\tlogits";
|
||||||
if (mirostat == 0) {
|
|
||||||
for (const auto & cnstr : constraints) {
|
for (int i = 0; i < llama_sampler_n_constraints(gsmpl->smpl); i++) {
|
||||||
const auto name = gpt_constraint_type_to_str(cnstr);
|
const auto * cnstr = llama_sampler_constraint_get(gsmpl->smpl, i);
|
||||||
if (!name.empty()) {
|
result += " -> " + std::string(cnstr->iface->name(cnstr)) + " ";
|
||||||
result += "-> " + name + " ";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result += "-> mirostat ";
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -70,33 +65,33 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
|
||||||
for (const auto & cnstr : params.constraints) {
|
for (const auto & cnstr : params.constraints) {
|
||||||
switch (cnstr) {
|
switch (cnstr) {
|
||||||
case GPT_CONSTRAINT_TYPE_TOP_K:
|
case GPT_CONSTRAINT_TYPE_TOP_K:
|
||||||
llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
|
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
|
||||||
break;
|
break;
|
||||||
case GPT_CONSTRAINT_TYPE_TOP_P:
|
case GPT_CONSTRAINT_TYPE_TOP_P:
|
||||||
llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));
|
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));
|
||||||
break;
|
break;
|
||||||
case GPT_CONSTRAINT_TYPE_MIN_P:
|
case GPT_CONSTRAINT_TYPE_MIN_P:
|
||||||
llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep));
|
llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep));
|
||||||
break;
|
break;
|
||||||
case GPT_CONSTRAINT_TYPE_TFS_Z:
|
case GPT_CONSTRAINT_TYPE_TFS_Z:
|
||||||
llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
|
llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
|
||||||
break;
|
break;
|
||||||
case GPT_CONSTRAINT_TYPE_TYPICAL_P:
|
case GPT_CONSTRAINT_TYPE_TYPICAL_P:
|
||||||
llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep));
|
llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep));
|
||||||
break;
|
break;
|
||||||
case GPT_CONSTRAINT_TYPE_TEMPERATURE:
|
case GPT_CONSTRAINT_TYPE_TEMPERATURE:
|
||||||
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false && "unknown constraint type");
|
GGML_ASSERT(false && "unknown constraint type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (params.mirostat == 1) {
|
} else if (params.mirostat == 1) {
|
||||||
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
|
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
|
||||||
llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
|
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
|
||||||
} else if (params.mirostat == 2) {
|
} else if (params.mirostat == 2) {
|
||||||
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
|
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
|
||||||
llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
|
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false && "unknown mirostat version");
|
GGML_ASSERT(false && "unknown mirostat version");
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,10 +54,7 @@ struct gpt_sampler_params {
|
||||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||||
|
|
||||||
// print the parameters into a string
|
// print the parameters into a string
|
||||||
std::string print_all() const;
|
std::string print() const;
|
||||||
|
|
||||||
// print the constraints into a string
|
|
||||||
std::string print_constraints() const;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// gpt_sampler extends llama_sampler with additional functionality:
|
// gpt_sampler extends llama_sampler with additional functionality:
|
||||||
|
@ -100,6 +97,9 @@ llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_da
|
||||||
|
|
||||||
// helpers
|
// helpers
|
||||||
|
|
||||||
|
// print the constraints into a string
|
||||||
|
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);
|
||||||
|
|
||||||
// get a string representation of the last accepted tokens
|
// get a string representation of the last accepted tokens
|
||||||
std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n);
|
std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n);
|
||||||
|
|
||||||
|
|
|
@ -61,9 +61,9 @@ defer {
|
||||||
llama_sampler_free(smpl)
|
llama_sampler_free(smpl)
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
|
llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1));
|
||||||
llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9, 1));
|
llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1));
|
||||||
llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4));
|
llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4));
|
||||||
|
|
||||||
let n_ctx = llama_n_ctx(context)
|
let n_ctx = llama_n_ctx(context)
|
||||||
|
|
||||||
|
|
|
@ -70,9 +70,9 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
llama_sampler * smpl = llama_sampler_init(model, sparams);
|
llama_sampler * smpl = llama_sampler_init(model, sparams);
|
||||||
|
|
||||||
llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
|
llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
|
||||||
llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
|
llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
|
||||||
llama_sampler_add_constraint(smpl, llama_constraint_init_temp (params.sparams.temp));
|
llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp));
|
||||||
|
|
||||||
if (ctx == NULL) {
|
if (ctx == NULL) {
|
||||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||||
|
|
|
@ -301,7 +301,7 @@ int main(int argc, char ** argv) {
|
||||||
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
|
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LOG_TEE("sampling: \n%s\n", sparams.print_all().c_str());
|
LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
|
||||||
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||||
LOG_TEE("\n\n");
|
LOG_TEE("\n\n");
|
||||||
|
|
||||||
|
|
|
@ -457,8 +457,15 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str());
|
|
||||||
LOG_TEE("sampling constr: \n%s\n", sparams.print_constraints().c_str());
|
smpl = gpt_sampler_init(model, sparams);
|
||||||
|
if (!smpl) {
|
||||||
|
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
|
||||||
|
LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
|
||||||
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||||
|
|
||||||
// group-attention state
|
// group-attention state
|
||||||
|
@ -525,12 +532,6 @@ int main(int argc, char ** argv) {
|
||||||
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
|
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
|
||||||
}
|
}
|
||||||
|
|
||||||
smpl = gpt_sampler_init(model, sparams);
|
|
||||||
if (!smpl) {
|
|
||||||
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
|
||||||
exit(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (llama_model_has_encoder(model)) {
|
if (llama_model_has_encoder(model)) {
|
||||||
int enc_input_size = embd_inp.size();
|
int enc_input_size = embd_inp.size();
|
||||||
llama_token * enc_input_buf = embd_inp.data();
|
llama_token * enc_input_buf = embd_inp.data();
|
||||||
|
|
|
@ -1012,7 +1012,7 @@ extern "C" {
|
||||||
// The llama_sampler object contains the entire sampling information:
|
// The llama_sampler object contains the entire sampling information:
|
||||||
//
|
//
|
||||||
// - RNG state (seed and generator)
|
// - RNG state (seed and generator)
|
||||||
// - Custom set of constraints (see llama_sampler_add_constraint)
|
// - Custom set of constraints (see llama_sampler_constraint_add)
|
||||||
// - Sampling method (greedy, dist)
|
// - Sampling method (greedy, dist)
|
||||||
// - Previous tokens
|
// - Previous tokens
|
||||||
//
|
//
|
||||||
|
@ -1083,7 +1083,7 @@ extern "C" {
|
||||||
|
|
||||||
LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr);
|
LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr);
|
||||||
|
|
||||||
// important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_add_constraint)
|
// important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add)
|
||||||
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);
|
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);
|
||||||
|
|
||||||
LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token);
|
LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token);
|
||||||
|
@ -1102,7 +1102,10 @@ extern "C" {
|
||||||
LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl);
|
LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl);
|
||||||
|
|
||||||
// important: takes ownership of the constraint object and will free it in llama_sampler_free
|
// important: takes ownership of the constraint object and will free it in llama_sampler_free
|
||||||
LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr);
|
LLAMA_API void llama_sampler_constraint_add( struct llama_sampler * smpl, struct llama_constraint * cnstr);
|
||||||
|
LLAMA_API int llama_sampler_n_constraints (const struct llama_sampler * smpl);
|
||||||
|
LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i);
|
||||||
|
|
||||||
|
|
||||||
LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
|
LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
|
||||||
LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p);
|
LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p);
|
||||||
|
|
|
@ -1215,10 +1215,22 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) {
|
||||||
// TODO: should we reset the timings?
|
// TODO: should we reset the timings?
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) {
|
void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) {
|
||||||
smpl.constraints.push_back(cnstr);
|
smpl.constraints.push_back(cnstr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl) {
|
||||||
|
return smpl.constraints.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith) {
|
||||||
|
if (ith < 0 || ith >= (int) smpl.constraints.size()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return smpl.constraints[ith];
|
||||||
|
}
|
||||||
|
|
||||||
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
|
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
|
||||||
smpl.prev.push_back(token);
|
smpl.prev.push_back(token);
|
||||||
|
|
||||||
|
|
|
@ -109,7 +109,9 @@ void llama_sampler_free_impl ( struct llama_sampler * smp
|
||||||
struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl);
|
struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl);
|
||||||
void llama_sampler_reset_impl( struct llama_sampler & smpl);
|
void llama_sampler_reset_impl( struct llama_sampler & smpl);
|
||||||
|
|
||||||
void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr);
|
void llama_sampler_constraint_add_impl( struct llama_sampler & smpl, struct llama_constraint * cnstr);
|
||||||
|
int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl);
|
||||||
|
struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith);
|
||||||
|
|
||||||
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token);
|
void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token);
|
||||||
void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * cur_p);
|
void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * cur_p);
|
||||||
|
|
|
@ -20729,8 +20729,16 @@ llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smp
|
||||||
return &smpl->cur_p;
|
return &smpl->cur_p;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr) {
|
void llama_sampler_constraint_add(struct llama_sampler * smpl, struct llama_constraint * cnstr) {
|
||||||
llama_sampler_add_constraint_impl(*smpl, cnstr);
|
llama_sampler_constraint_add_impl(*smpl, cnstr);
|
||||||
|
}
|
||||||
|
|
||||||
|
int llama_sampler_n_constraints (const struct llama_sampler * smpl) {
|
||||||
|
return llama_sampler_n_constraints_impl(*smpl);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i) {
|
||||||
|
return llama_sampler_constraint_get_impl(*smpl, i);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue