Include review comments

This commit is contained in:
Martin Krasser 2023-08-07 11:46:26 +02:00
parent d9f75f3ccf
commit b6524985df

View file

@ -262,7 +262,7 @@ struct llama_server_context
return true; return true;
} }
void loadGrammar() bool loadGrammar()
{ {
if (!params.grammar.empty()) { if (!params.grammar.empty()) {
grammar_parser::parse_state parsed_grammar; grammar_parser::parse_state parsed_grammar;
@ -270,18 +270,15 @@ struct llama_server_context
parsed_grammar = grammar_parser::parse(params.grammar.c_str()); parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors // will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) { if (parsed_grammar.rules.empty()) {
fprintf(stderr, "%s: grammar parse error\n", __func__); LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
return; return false;
} }
fprintf(stderr, "%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, parsed_grammar); grammar_parser::print_grammar(stderr, parsed_grammar);
fprintf(stderr, "\n");
{ {
auto it = params.logit_bias.find(llama_token_eos()); auto it = params.logit_bias.find(llama_token_eos());
if (it != params.logit_bias.end() && it->second == -INFINITY) { if (it != params.logit_bias.end() && it->second == -INFINITY) {
fprintf(stderr, LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
"%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
} }
} }
@ -289,6 +286,7 @@ struct llama_server_context
grammar = llama_grammar_init( grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
} }
return true;
} }
void loadPrompt() void loadPrompt()
@ -1224,7 +1222,12 @@ int main(int argc, char **argv)
parse_options_completion(json::parse(req.body), llama); parse_options_completion(json::parse(req.body), llama);
llama.loadGrammar(); if (!llama.loadGrammar())
{
res.status = 400;
return;
}
llama.loadPrompt(); llama.loadPrompt();
llama.beginCompletion(); llama.beginCompletion();
@ -1376,8 +1379,12 @@ int main(int argc, char **argv)
svr.set_error_handler([](const Request &, Response &res) svr.set_error_handler([](const Request &, Response &res)
{ {
res.set_content("File Not Found", "text/plain"); if (res.status == 400) {
res.status = 404; }); res.set_content("Invalid request", "text/plain");
} else {
res.set_content("File Not Found", "text/plain");
res.status = 404;
} });
// set timeouts and change hostname and port // set timeouts and change hostname and port
svr.set_read_timeout(sparams.read_timeout); svr.set_read_timeout(sparams.read_timeout);