squash! llama : return nullptr from llama_grammar_init
Add checks for nullptr when calling llama_grammar_init. Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>
This commit is contained in:
parent
6189bce423
commit
35d96805f7
3 changed files with 17 additions and 3 deletions
|
@ -28,9 +28,13 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
|
||||||
|
|
||||||
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
|
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
|
||||||
|
|
||||||
result->grammar = llama_grammar_init(
|
struct llama_grammar * grammar = llama_grammar_init(
|
||||||
grammar_rules.data(),
|
grammar_rules.data(),
|
||||||
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
||||||
|
if (grammar == nullptr) {
|
||||||
|
throw std::runtime_error("Failed to initialize llama_grammar");
|
||||||
|
}
|
||||||
|
result->grammar = grammar;
|
||||||
}
|
}
|
||||||
|
|
||||||
result->prev.resize(params.n_prev);
|
result->prev.resize(params.n_prev);
|
||||||
|
@ -59,9 +63,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
|
||||||
if (!ctx->parsed_grammar.rules.empty()) {
|
if (!ctx->parsed_grammar.rules.empty()) {
|
||||||
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
|
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
|
||||||
|
|
||||||
ctx->grammar = llama_grammar_init(
|
struct llama_grammar * grammar = llama_grammar_init(
|
||||||
grammar_rules.data(),
|
grammar_rules.data(),
|
||||||
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
|
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
|
||||||
|
if (grammar == nullptr) {
|
||||||
|
throw std::runtime_error("Failed to initialize llama_grammar");
|
||||||
|
}
|
||||||
|
ctx->grammar = grammar;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
||||||
|
|
|
@ -101,7 +101,9 @@ int main(int argc, char** argv) {
|
||||||
auto grammar = llama_grammar_init(
|
auto grammar = llama_grammar_init(
|
||||||
grammar_rules.data(),
|
grammar_rules.data(),
|
||||||
grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||||
|
if (grammar == nullptr) {
|
||||||
|
throw std::runtime_error("Failed to initialize llama_grammar");
|
||||||
|
}
|
||||||
// Read the input file
|
// Read the input file
|
||||||
std::string input_str;
|
std::string input_str;
|
||||||
{
|
{
|
||||||
|
|
|
@ -116,6 +116,10 @@ int main()
|
||||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
||||||
grammar = llama_grammar_init(
|
grammar = llama_grammar_init(
|
||||||
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||||
|
if (grammar == nullptr)
|
||||||
|
{
|
||||||
|
throw std::runtime_error("Failed to initialize llama_grammar");
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
|
std::vector<std::vector<llama_grammar_element>> expected_stacks = {
|
||||||
{
|
{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue