diff --git a/include/llama.h b/include/llama.h index 244e07705..baa136537 100644 --- a/include/llama.h +++ b/include/llama.h @@ -62,7 +62,6 @@ extern "C" { struct llama_model; struct llama_context; struct llama_sampler; - struct llama_constraint; struct llama_sampling; typedef int32_t llama_pos; @@ -1169,20 +1168,27 @@ extern "C" { // Sampling v2 API // - // samplers - - LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_params params); - LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); - LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); - LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl); - - LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); - - LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); - LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i); - // constraints + struct llama_constraint; + + typedef void * llama_constraint_context_t; + + struct llama_constraint_i { + void (*accept)(struct llama_constraint * cnstr, llama_token token); + void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); + void (*reset) (struct llama_constraint * cnstr); // e.g. for grammar and penalty constraints + void (*free) (struct llama_constraint * cnstr); + + // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph + //void (*apply_ggml) (struct llama_constraint * cnstr, ...); + }; + + struct llama_constraint { + struct llama_constraint_i * iface; + llama_constraint_context_t ctx; + }; + LLAMA_API struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep); // ... @@ -1191,6 +1197,20 @@ extern "C" { LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token); LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * candidates); + // samplers + + LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_params params); + LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); + LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl); + + // TODO: should this take ownership so the user does not need to call llama_constraint_free + // or should just make a reference to the constraint so that it can be reused in multiple llama_sampler? + LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); + + LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); + LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i); + // // Model split //