diff --git a/include/llama.h b/include/llama.h index baa136537..bd756fc5c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -62,7 +62,7 @@ extern "C" { struct llama_model; struct llama_context; struct llama_sampler; - struct llama_sampling; + struct llama_sampling; // TODO: remove before merge typedef int32_t llama_pos; typedef int32_t llama_token; @@ -414,7 +414,8 @@ extern "C" { } llama_sampling_params; typedef struct llama_sampler_params { - bool dummy; + uint32_t seed; // the seed used to initialize the rng of the sampler + // TODO: add type of sampler: greedy, dist, mirostat, etc. } llama_sampler_params; @@ -1175,10 +1176,11 @@ extern "C" { typedef void * llama_constraint_context_t; struct llama_constraint_i { - void (*accept)(struct llama_constraint * cnstr, llama_token token); + void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); - void (*reset) (struct llama_constraint * cnstr); // e.g. for grammar and penalty constraints - void (*free) (struct llama_constraint * cnstr); + void (*reset) (struct llama_constraint * cnstr); // e.g. for grammar and penalty constraints, can be NULL + void (*copy) (struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src); + void (*free) (struct llama_constraint * cnstr); // can be NULL // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph //void (*apply_ggml) (struct llama_constraint * cnstr, ...); @@ -1192,10 +1194,13 @@ extern "C" { LLAMA_API struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep); // ... + + // do not call if used with llama_sampler_add_constraint LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token); LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * candidates); + LLAMA_API void llama_constraint_reset (struct llama_constraint * cnstr); // samplers @@ -1206,6 +1211,8 @@ extern "C" { // TODO: should this take ownership so the user does not need to call llama_constraint_free // or should just make a reference to the constraint so that it can be reused in multiple llama_sampler? + // + // seems better to take the ownership, otherwise the copying of the sampler will be more complicated LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 8abfc3fc6..91b95d3af 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -50,6 +50,10 @@ struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & voca } void llama_sampling_free_impl(struct llama_sampling * sampling) { + if (sampling == nullptr) { + return; + } + delete sampling; } @@ -633,3 +637,166 @@ llama_token llama_sampling_prev_impl(const struct llama_sampling & smpl, int ith int llama_sampling_n_prev_impl(const struct llama_sampling & smpl) { return smpl.prev.size(); } + +// +// sampling v2 +// + +// constraints + +// top-k + +struct llama_constraint_context_top_k { + int32_t k; + size_t min_keep; +}; + +static struct llama_constraint_i llama_constraint_top_k_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx; + llama_sampling_top_k_impl(candidates, ctx->k, ctx->min_keep); + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_top_k; + const auto * ctx_src = (const llama_constraint_context_top_k *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_top_k *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + delete (llama_constraint_context_top_k *) cnstr->ctx; + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_top_k_i; + result->ctx = new llama_constraint_context_top_k{k, min_keep}; + + return result; +} + +// top-p + +struct llama_constraint_context_top_p { + float p; + size_t min_keep; +}; + +static struct llama_constraint_i llama_constraint_top_p_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_top_p * ctx = (llama_constraint_context_top_p *) cnstr->ctx; + llama_sampling_top_p_impl(candidates, ctx->p, ctx->min_keep); + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_top_p; + const auto * ctx_src = (const llama_constraint_context_top_p *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_top_p *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + delete (llama_constraint_context_top_p *) cnstr->ctx; + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_top_p_i; + result->ctx = new llama_constraint_context_top_p{p, min_keep}; + + return result; +} + +void llama_constraint_free_impl(struct llama_constraint * constraint) { + if (constraint->iface->free) { + constraint->iface->free(constraint); + } +} + +void llama_constraint_accept_impl(struct llama_constraint * constraint, llama_token token) { + if (constraint->iface->accept) { + constraint->iface->accept(constraint, token); + } +} + +void llama_constraint_apply_impl(struct llama_constraint * constraint, struct llama_token_data_array * candidates) { + GGML_ASSERT(constraint->iface->apply); + constraint->iface->apply(constraint, candidates); +} + +void llama_constraint_reset_impl(struct llama_constraint * constraint) { + if (constraint->iface->reset) { + constraint->iface->reset(constraint); + } +} + +// samplers + +struct llama_sampler * llama_sampler_init_impl(struct llama_sampler_params params) { + auto * result = new llama_sampler; + + result->params = params; + + result->rng.seed(params.seed); + + return result; +} + +void llama_sampler_free_impl(struct llama_sampler * smpl) { + if (smpl == nullptr) { + return; + } + + for (auto * constraint : smpl->constraints) { + llama_constraint_free_impl(constraint); + } + + delete smpl; +} + +struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) { + auto * result = new llama_sampler; + + *result = smpl; + + // copy the constraints objects + result->constraints.clear(); + for (const auto & constraint : smpl.constraints) { + GGML_ASSERT(constraint->iface->copy); + + result->constraints.push_back(new llama_constraint); + result->constraints.back()->iface = constraint->iface; + result->constraints.back()->iface->copy(result->constraints.back(), constraint); + } + + return result; +} + +void llama_sampler_reset_impl(struct llama_sampler & smpl) { + smpl.prev.clear(); + + for (auto * constraint : smpl.constraints) { + llama_constraint_reset_impl(constraint); + } + + // TODO: should we reset the timings? +} + +void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) { + smpl.constraints.push_back(cnstr); +} + +void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { + smpl.prev.push_back(token); + + for (auto * constraint : smpl.constraints) { + llama_constraint_accept_impl(constraint, token); + } +} diff --git a/src/llama-sampling.h b/src/llama-sampling.h index c51542259..4141c4fa3 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -105,3 +105,48 @@ void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, llama_token llama_sampling_prev_impl (const struct llama_sampling & smpl, int ith); int llama_sampling_n_prev_impl(const struct llama_sampling & smpl); + + +// +// sampling v2 +// + +// constraints + +struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep); +struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep); + +void llama_constraint_free_impl(struct llama_constraint * constraint); + +void llama_constraint_accept_impl(struct llama_constraint * constraint, llama_token token); +void llama_constraint_apply_impl (struct llama_constraint * constraint, struct llama_token_data_array * candidates); +void llama_constraint_reset_impl (struct llama_constraint * constraint); + +// samplers + +struct llama_sampler { + llama_sampler_params params; + + // state + + std::mt19937 rng; + + // TODO: move to a standalone penalty constraint? + ring_buffer prev; + + std::vector constraints; + + // timing + + mutable int64_t t_sample_us = 0; + + mutable int32_t n_sample = 0; +}; + +struct llama_sampler * llama_sampler_init_impl ( struct llama_sampler_params params); +void llama_sampler_free_impl ( struct llama_sampler * smpl); +struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl); +void llama_sampler_reset_impl( struct llama_sampler & smpl); + +void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr); +void llama_sampler_accept_impl (struct llama_sampler & smpl, llama_token token); diff --git a/src/llama.cpp b/src/llama.cpp index d0ca96aca..5f06f33ad 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17937,7 +17937,7 @@ struct llama_context_params llama_context_default_params() { struct llama_sampler_params llama_sampler_default_params() { struct llama_sampler_params result = { - /*.dummy =*/ false, + /*.seed =*/ LLAMA_DEFAULT_SEED, }; return result; @@ -20971,6 +20971,70 @@ llama_token llama_sampling_last(const struct llama_sampling * smpl) { return llama_sampling_prev_impl(*smpl, 0); } +// +// sampling v2 +// + +struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) { + return llama_constraint_init_top_k_impl(k, min_keep); +} + +struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) { + return llama_constraint_init_top_p_impl(p, min_keep); +} + +void llama_constraint_free(struct llama_constraint * cnstr) { + if (cnstr == nullptr) { + return; + } + + llama_constraint_free_impl(cnstr); +} + +void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token) { + llama_constraint_accept_impl(cnstr, token); +} + +void llama_constraint_apply(struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_apply_impl(cnstr, candidates); +} + +void llama_constraint_reset(struct llama_constraint * cnstr) { + llama_constraint_reset_impl(cnstr); +} + +struct llama_sampler * llama_sampler_init(struct llama_sampler_params params) { + return llama_sampler_init_impl(params); +} + +void llama_sampler_free(struct llama_sampler * smpl) { + if (smpl == nullptr) { + return; + } + + llama_sampler_free_impl(smpl); +} + +struct llama_sampler * llama_sampler_cp(const struct llama_sampler * smpl) { + return llama_sampler_cp_impl(*smpl); +} + +void llama_sampler_reset(struct llama_sampler * smpl) { + llama_sampler_reset_impl(*smpl); +} + +void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr) { + llama_sampler_add_constraint_impl(*smpl, cnstr); +} + +void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + llama_sampler_accept_impl(*smpl, token); +} + +llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i) { + GGML_ABORT("not implemented"); +} + // // model split //