llama : refactor sampling v2 (#9294)

- Add `struct llama_sampler` and `struct llama_sampler_i`
- Add `llama_sampler_` API
- Add `llama_sampler_chain_` API for chaining multiple samplers
- Remove `LLAMA_API_INTERNAL`
- Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`
This commit is contained in:
Georgi Gerganov 2024-09-07 15:16:19 +03:00 committed by GitHub
parent 947538acb8
commit df270ef745
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 3497 additions and 2914 deletions

View file

@@ -2,16 +2,15 @@
#undef NDEBUG
#endif
#define LLAMA_API_INTERNAL
#include "llama.h"
#include "grammar-parser.h"
#include "llama-grammar.h"
#include <cassert>
#include <stdexcept>
int main()
{
grammar_parser::parse_state parsed_grammar;
llama_grammar_parser parsed_grammar;
std::vector<std::pair<std::string, uint32_t>> expected = {
{"expr", 2},
@@ -117,7 +116,7 @@ int main()
llama_grammar * grammar = NULL;
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
if (grammar == nullptr)
{
throw std::runtime_error("Failed to initialize llama_grammar");
@@ -174,13 +173,13 @@ int main()
}};
auto index = 0;
for (auto stack : llama_grammar_get_stacks(grammar))
for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar))
{
// compare stack to expected_stack
for (uint32_t i = 0; i < stack.size(); i++)
{
auto element = stack[i];
auto expected_element = expected_stacks[index][i];
const llama_grammar_element * element = stack[i];
const llama_grammar_element & expected_element = expected_stacks[index][i];
// pretty print error message before asserting
if (expected_element.type != element->type || expected_element.value != element->value)
@@ -403,6 +402,8 @@ int main()
delete[] candidate.code_points;
candidate.code_points = nullptr;
}
llama_grammar_free(grammar);
llama_grammar_free_impl(grammar);
return 0;
}