cont : add rest of the existing samplers [no ci]

This commit is contained in:
Georgi Gerganov 2024-09-03 15:17:02 +03:00
parent 1b07dc51c6
commit 71293a6456
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
5 changed files with 312 additions and 10 deletions

View file

@ -1176,11 +1176,11 @@ extern "C" {
typedef void * llama_constraint_context_t; typedef void * llama_constraint_context_t;
struct llama_constraint_i { struct llama_constraint_i {
void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL
void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); // required
void (*reset) (struct llama_constraint * cnstr); // e.g. for grammar and penalty constraints, can be NULL void (*reset) (struct llama_constraint * cnstr); // can be NULL
void (*copy) (struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src); void (*copy) (struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src); // can be NULL if ctx is NULL
void (*free) (struct llama_constraint * cnstr); // can be NULL void (*free) (struct llama_constraint * cnstr); // can be NULL
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
//void (*apply_ggml) (struct llama_constraint * cnstr, ...); //void (*apply_ggml) (struct llama_constraint * cnstr, ...);
@ -1191,9 +1191,14 @@ extern "C" {
llama_constraint_context_t ctx; llama_constraint_context_t ctx;
}; };
LLAMA_API struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep);
LLAMA_API struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep);
// ... LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep);
LLAMA_API struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep);
LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep);
LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t);
LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent);
LLAMA_API struct llama_constraint * llama_constraint_init_grammar (struct llama_model * model, const char * grammar_str, const char * grammar_root);
// do not call if used with llama_sampler_add_constraint // do not call if used with llama_sampler_add_constraint
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);

View file

@ -1043,6 +1043,10 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
} }
void llama_grammar_free_impl(struct llama_grammar * grammar) { void llama_grammar_free_impl(struct llama_grammar * grammar) {
if (grammar == nullptr) {
return;
}
delete grammar; delete grammar;
} }

View file

@ -718,6 +718,269 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k
return result; return result;
} }
// min-p
struct llama_constraint_context_min_p {
float p;
size_t min_keep;
};
static struct llama_constraint_i llama_constraint_min_p_i = {
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
llama_constraint_context_min_p * ctx = (llama_constraint_context_min_p *) cnstr->ctx;
llama_sampling_min_p_impl(candidates, ctx->p, ctx->min_keep);
},
/* .reset = */ nullptr,
/* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) {
cnstr->ctx = new llama_constraint_context_min_p;
const auto * ctx_src = (const llama_constraint_context_min_p *) cnstr_src->ctx;
auto * ctx_dst = ( llama_constraint_context_min_p *) cnstr->ctx;
*ctx_dst = *ctx_src;
},
/* .free = */ [](struct llama_constraint * cnstr) {
if (cnstr->ctx) {
delete (llama_constraint_context_min_p *) cnstr->ctx;
}
delete cnstr;
}
};
struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) {
struct llama_constraint * result = new llama_constraint;
result->iface = &llama_constraint_min_p_i;
result->ctx = new llama_constraint_context_min_p{p, min_keep};
return result;
}
// tail-free
struct llama_constraint_context_tail_free {
float z;
size_t min_keep;
};
static struct llama_constraint_i llama_constraint_tail_free_i = {
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
llama_constraint_context_tail_free * ctx = (llama_constraint_context_tail_free *) cnstr->ctx;
llama_sampling_tail_free_impl(candidates, ctx->z, ctx->min_keep);
},
/* .reset = */ nullptr,
/* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) {
cnstr->ctx = new llama_constraint_context_tail_free;
const auto * ctx_src = (const llama_constraint_context_tail_free *) cnstr_src->ctx;
auto * ctx_dst = ( llama_constraint_context_tail_free *) cnstr->ctx;
*ctx_dst = *ctx_src;
},
/* .free = */ [](struct llama_constraint * cnstr) {
if (cnstr->ctx) {
delete (llama_constraint_context_tail_free *) cnstr->ctx;
}
delete cnstr;
}
};
struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) {
struct llama_constraint * result = new llama_constraint;
result->iface = &llama_constraint_tail_free_i;
result->ctx = new llama_constraint_context_tail_free{z, min_keep};
return result;
}
// typical
struct llama_constraint_context_typical {
float p;
size_t min_keep;
};
static struct llama_constraint_i llama_constraint_typical_i = {
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
llama_constraint_context_typical * ctx = (llama_constraint_context_typical *) cnstr->ctx;
llama_sampling_typical_impl(candidates, ctx->p, ctx->min_keep);
},
/* .reset = */ nullptr,
/* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) {
cnstr->ctx = new llama_constraint_context_typical;
const auto * ctx_src = (const llama_constraint_context_typical *) cnstr_src->ctx;
auto * ctx_dst = ( llama_constraint_context_typical *) cnstr->ctx;
*ctx_dst = *ctx_src;
},
/* .free = */ [](struct llama_constraint * cnstr) {
if (cnstr->ctx) {
delete (llama_constraint_context_typical *) cnstr->ctx;
}
delete cnstr;
}
};
struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) {
struct llama_constraint * result = new llama_constraint;
result->iface = &llama_constraint_typical_i;
result->ctx = new llama_constraint_context_typical{p, min_keep};
return result;
}
// temp
struct llama_constraint_context_temp {
float temp;
};
static struct llama_constraint_i llama_constraint_temp_i = {
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
llama_constraint_context_temp * ctx = (llama_constraint_context_temp *) cnstr->ctx;
llama_sampling_temp_impl(candidates, ctx->temp);
},
/* .reset = */ nullptr,
/* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) {
cnstr->ctx = new llama_constraint_context_temp;
const auto * ctx_src = (const llama_constraint_context_temp *) cnstr_src->ctx;
auto * ctx_dst = ( llama_constraint_context_temp *) cnstr->ctx;
*ctx_dst = *ctx_src;
},
/* .free = */ [](struct llama_constraint * cnstr) {
if (cnstr->ctx) {
delete (llama_constraint_context_temp *) cnstr->ctx;
}
delete cnstr;
}
};
struct llama_constraint * llama_constraint_init_temp_impl(float temp) {
struct llama_constraint * result = new llama_constraint;
result->iface = &llama_constraint_temp_i;
result->ctx = new llama_constraint_context_temp{temp};
return result;
}
// temp-ext
struct llama_constraint_context_temp_ext {
float temp;
float delta;
float exponent;
};
static struct llama_constraint_i llama_constraint_temp_ext_i = {
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
llama_constraint_context_temp_ext * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx;
if (ctx->delta > 0) {
const float temp_min = std::max(0.0f, ctx->temp - ctx->delta);
const float temp_max = ctx->temp + ctx->delta;
llama_sampling_entropy_impl(candidates, temp_min, temp_max, ctx->exponent);
} else {
llama_sampling_temp_impl(candidates, ctx->temp);
}
},
/* .reset = */ nullptr,
/* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) {
cnstr->ctx = new llama_constraint_context_temp_ext;
const auto * ctx_src = (const llama_constraint_context_temp_ext *) cnstr_src->ctx;
auto * ctx_dst = ( llama_constraint_context_temp_ext *) cnstr->ctx;
*ctx_dst = *ctx_src;
},
/* .free = */ [](struct llama_constraint * cnstr) {
if (cnstr->ctx) {
delete (llama_constraint_context_temp_ext *) cnstr->ctx;
}
delete cnstr;
}
};
struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) {
struct llama_constraint * result = new llama_constraint;
result->iface = &llama_constraint_temp_ext_i;
result->ctx = new llama_constraint_context_temp_ext{temp, delta, exponent};
return result;
}
// grammar
struct llama_constraint_context_grammar {
std::string grammar_str;
std::string grammar_root;
struct llama_grammar * grammar;
};
static struct llama_constraint_i llama_constraint_grammar_i = {
/* .accept = */ nullptr,
/* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) {
llama_constraint_context_grammar * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
if (ctx->grammar) {
llama_sampling_grammar_impl(candidates, *ctx->grammar);
}
},
/* .reset = */ [](struct llama_constraint * cnstr) {
llama_constraint_context_grammar * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
if (ctx->grammar) {
llama_grammar_free_impl(ctx->grammar);
ctx->grammar = nullptr;
}
if (!ctx->grammar_str.empty()) {
ctx->grammar = llama_grammar_init_impl(nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
}
},
/* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) {
cnstr->ctx = new llama_constraint_context_grammar;
const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr_src->ctx;
auto * ctx_dst = ( llama_constraint_context_grammar *) cnstr->ctx;
*ctx_dst = *ctx_src;
if (ctx_src->grammar) {
ctx_dst->grammar = llama_grammar_cp_impl(*ctx_src->grammar);
} else {
ctx_dst->grammar = nullptr;
}
},
/* .free = */ [](struct llama_constraint * cnstr) {
if (cnstr->ctx) {
{
auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx;
llama_grammar_free_impl(ctx->grammar);
}
delete (llama_constraint_context_grammar *) cnstr->ctx;
}
delete cnstr;
}
};
struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
struct llama_constraint * result = new llama_constraint;
result->iface = &llama_constraint_grammar_i;
result->ctx = new llama_constraint_context_grammar;
auto * ctx = (llama_constraint_context_grammar *) result->ctx;
if (grammar_str != nullptr && grammar_str[0] != '\0') {
ctx->grammar = llama_grammar_init_impl(&vocab, grammar_str, grammar_root);
} else {
ctx->grammar = nullptr;
}
return result;
}
void llama_constraint_free_impl(struct llama_constraint * cnstr) { void llama_constraint_free_impl(struct llama_constraint * cnstr) {
if (cnstr->iface->free && cnstr) { if (cnstr->iface->free && cnstr) {
cnstr->iface->free(cnstr); cnstr->iface->free(cnstr);

View file

@ -113,8 +113,14 @@ int llama_sampling_n_prev_impl(const struct llama_sampling & smpl);
// constraints // constraints
struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep); struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep);
struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep); struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep);
struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep);
struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep);
struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep);
struct llama_constraint * llama_constraint_init_temp_impl (float t);
struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent);
struct llama_constraint * llama_constraint_init_grammar_impl (const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root);
void llama_constraint_free_impl(struct llama_constraint * cnstr); void llama_constraint_free_impl(struct llama_constraint * cnstr);

View file

@ -20983,6 +20983,30 @@ struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep)
return llama_constraint_init_top_p_impl(p, min_keep); return llama_constraint_init_top_p_impl(p, min_keep);
} }
struct llama_constraint * llama_constraint_init_min_p(float p, int32_t min_keep) {
return llama_constraint_init_min_p_impl(p, min_keep);
}
struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep) {
return llama_constraint_init_tail_free_impl(z, min_keep);
}
struct llama_constraint * llama_constraint_init_typical(float p, int32_t min_keep) {
return llama_constraint_init_typical_impl(p, min_keep);
}
struct llama_constraint * llama_constraint_init_temp(float temp) {
return llama_constraint_init_temp_impl(temp);
}
struct llama_constraint * llama_constraint_init_temp_ext(float temp, float delta, float exponent) {
return llama_constraint_init_temp_ext_impl(temp, delta, exponent);
}
struct llama_constraint * llama_constraint_init_grammar(struct llama_model * model, const char * grammar_str, const char * grammar_root) {
return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root);
}
void llama_constraint_free(struct llama_constraint * cnstr) { void llama_constraint_free(struct llama_constraint * cnstr) {
if (cnstr == nullptr) { if (cnstr == nullptr) {
return; return;