llama : move grammar code into llama-grammar

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-07-16 16:09:08 +03:00
parent 0ddc8e361c
commit 675f305f31
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
12 changed files with 742 additions and 672 deletions

View file

@ -1003,6 +1003,18 @@ extern "C" {
LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
/// @details Apply constraints from grammar
LLAMA_API void llama_grammar_sample(
struct llama_context * ctx,
llama_token_data_array * candidates,
const struct llama_grammar * grammar);
/// @details Accepts the sampled token into the grammar
LLAMA_API void llama_grammar_accept_token(
struct llama_context * ctx,
struct llama_grammar * grammar,
llama_token token);
//
// Sampling functions
//
@ -1121,18 +1133,6 @@ extern "C" {
struct llama_context * ctx,
llama_token_data_array * candidates);
/// @details Apply constraints from grammar
LLAMA_API void llama_sample_grammar(
struct llama_context * ctx,
llama_token_data_array * candidates,
const struct llama_grammar * grammar);
/// @details Accepts the sampled token into the grammar
LLAMA_API void llama_grammar_accept_token(
struct llama_context * ctx,
struct llama_grammar * grammar,
llama_token token);
//
// Model split
//
@ -1175,38 +1175,41 @@ extern "C" {
struct ggml_tensor;
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
struct llama_context * ctx
);
struct llama_partial_utf8 {
uint32_t value; // bit value so far (unshifted)
int n_remain; // num bytes remaining; -1 indicates invalid sequence
};
struct llama_grammar {
const std::vector<std::vector<llama_grammar_element>> rules;
std::vector<std::vector<const llama_grammar_element *>> stacks;
// buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8;
};
struct llama_grammar_candidate {
size_t index;
const uint32_t * code_points;
llama_partial_utf8 partial_utf8;
};
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
struct llama_context * ctx
);
using llama_grammar_rules = std::vector<std::vector<llama_grammar_element>>;
using llama_grammar_stacks = std::vector<std::vector<const llama_grammar_element *>>;
const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);
void llama_grammar_accept(
const std::vector<std::vector<llama_grammar_element>> & rules,
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
const uint32_t chr,
std::vector<std::vector<const llama_grammar_element *>> & new_stacks);
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
const uint32_t chr,
llama_grammar_stacks & new_stacks);
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
const std::vector<std::vector<llama_grammar_element>> & rules,
const std::vector<const llama_grammar_element *> & stack,
const std::vector<llama_grammar_candidate> & candidates);
std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const std::string & src,
llama_partial_utf8 partial_start);
llama_partial_utf8 partial_start);
// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.