From 35d96805f7e504d27b8e8d8d0906b98414f44fb7 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 25 Jun 2024 05:35:39 +0200 Subject: [PATCH] squash! llama : return nullptr from llama_grammar_init Add checks for nullptr when calling llama_grammar_init. Signed-off-by: Daniel Bevenius --- common/sampling.cpp | 12 ++++++++++-- examples/gbnf-validator/gbnf-validator.cpp | 4 +++- tests/test-llama-grammar.cpp | 4 ++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index f1f803516..9f332fe57 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -28,9 +28,13 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ std::vector grammar_rules(result->parsed_grammar.c_rules()); - result->grammar = llama_grammar_init( + struct llama_grammar * grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); + if (grammar == nullptr) { + throw std::runtime_error("Failed to initialize llama_grammar"); + } + result->grammar = grammar; } result->prev.resize(params.n_prev); @@ -59,9 +63,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) { if (!ctx->parsed_grammar.rules.empty()) { std::vector grammar_rules(ctx->parsed_grammar.c_rules()); - ctx->grammar = llama_grammar_init( + struct llama_grammar * grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); + if (grammar == nullptr) { + throw std::runtime_error("Failed to initialize llama_grammar"); + } + ctx->grammar = grammar; } std::fill(ctx->prev.begin(), ctx->prev.end(), 0); diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 0406dc339..dd53ba9b1 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -101,7 +101,9 @@ int main(int argc, char** argv) { auto grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); - + if (grammar == nullptr) { + throw std::runtime_error("Failed to initialize llama_grammar"); + } // Read the input file std::string input_str; { diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 27ca4d265..c8badb206 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -116,6 +116,10 @@ int main() std::vector grammar_rules(parsed_grammar.c_rules()); grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + if (grammar == nullptr) + { + throw std::runtime_error("Failed to initialize llama_grammar"); + } std::vector> expected_stacks = { {