diff --git a/Makefile b/Makefile
index 4f26c0463..241bbeb5a 100644
--- a/Makefile
+++ b/Makefile
@@ -634,7 +634,7 @@ OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o
 llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
-COMMON_H_DEPS = common/common.h common/sampling.h common/log.h
+COMMON_H_DEPS = common/common.h common/sampling.h common/log.h common/error.h
 COMMON_DEPS = common.o sampling.o grammar-parser.o build-info.o
 
 common.o: common/common.cpp $(COMMON_H_DEPS)
diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index f79acfef1..3859bba81 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -57,7 +57,8 @@ add_library(${TARGET} STATIC
     grammar-parser.cpp
     train.h
     train.cpp
-    )
+    error.h
+)
 
 if (BUILD_SHARED_LIBS)
     set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
diff --git a/common/error.h b/common/error.h
new file mode 100644
index 000000000..1669c4918
--- /dev/null
+++ b/common/error.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include <string>
+#include <exception>
+#include <cstdio>
+
+class llama_error : public std::exception
+{
+private:
+    std::string _id;
+    std::string _description;
+
+public:
+    llama_error(const std::string & id, const std::string & description)
+        :
+        _id(id),
+        _description(description)
+    {
+        fprintf(stderr, "ERROR [%s]: %s\n", id.c_str(), description.c_str());
+    }
+
+    inline const std::string & id() const { return _id; }
+    inline const std::string & description() const { return _description; }
+};
diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp
index bf89a96f3..3d8e8bef9 100644
--- a/common/grammar-parser.cpp
+++ b/common/grammar-parser.cpp
@@ -1,4 +1,6 @@
 #include "grammar-parser.h"
+#include "error.h"
+
 #include <cstdint>
 #include <cwchar>
 #include <string>
@@ -280,8 +282,7 @@ namespace grammar_parser {
             }
             return state;
         } catch (const std::exception & err) {
-            fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
-            return parse_state();
+            throw llama_error("grammar.invalid", std::string(__func__) + ": error parsing grammar: " + err.what());
         }
     }
 
diff --git a/common/sampling.cpp b/common/sampling.cpp
index e67096bea..b0b960b73 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,4 +1,5 @@
 #include "sampling.h"
+#include "error.h"
 
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
@@ -8,13 +9,17 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
 
     // if there is a grammar, parse it
     if (!params.grammar.empty()) {
-        result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+        try {
+            result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+        } catch (const llama_error & err) {
+            delete result;
+            throw err;
+        }
 
         // will be empty (default) if there are parse errors
         if (result->parsed_grammar.rules.empty()) {
-            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
             delete result;
-            return nullptr;
+            throw llama_error("grammar.empty", std::string(__func__) + ": empty grammar");
         }
 
         std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
diff --git a/examples/server/README.md b/examples/server/README.md
index 0e9bd7fd4..582278770 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -275,6 +275,20 @@ Notice that each `probs` is an array of length `n_probs`.
 - `tokens_evaluated`: Number of tokens evaluated in total from the prompt
 - `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`)
 
+In case of an error the error details will be returned as follows:
+```json
+{
+  "error": {
+    "description": "parse: error parsing grammar: expecting name at (",
+    "id": "grammar.invalid"
+  }
+}
+```
+where:
+- `description` - human-readable description
+- `id` - unique ID for this error type
+
+
 - **POST** `/tokenize`: Tokenize a given text.
 
     *Options:*
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 080fa9bd5..f9f93ccca 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -23,6 +23,7 @@
 #include "index.js.hpp"
 #include "completion.js.hpp"
 #include "json-schema-to-grammar.mjs.hpp"
+#include "error.h"
 
 #include <cstddef>
 #include <thread>
@@ -674,6 +675,15 @@ struct llama_server_context
             slot->prompt = "";
         }
 
+        if (
+            (slot->prompt.is_string() && slot->prompt.get<std::string>().empty())
+            ||
+            (slot->prompt.is_array() && slot->prompt.empty())
+        )
+        {
+            throw llama_error("prompt.empty", "The prompt must not be empty");
+        }
+
         slot->sparams.penalty_prompt_tokens.clear();
         slot->sparams.use_penalty_prompt_tokens = false;
         const auto &penalty_prompt = data.find("penalty_prompt");
@@ -1132,6 +1142,28 @@ struct llama_server_context
         queue_results.send(res);
     }
 
+    static json error_to_json(const llama_error& error)
+    {
+        return {
+            { "error", {
+                { "id", error.id() },
+                { "description", error.description() }
+            } }
+        };
+    }
+
+    void send_error(task_server& task, const llama_error& error)
+    {
+        LOG_TEE("task %i - error: %s - %s\n", task.id, error.id().c_str(), error.description().c_str());
+        task_result res;
+        res.id = task.id;
+        res.multitask_id = task.multitask_id;
+        res.stop = false;
+        res.error = true;
+        res.result_json = { { "content", error_to_json(error).dump() } };
+        queue_results.send(res);
+    }
+
     json get_formated_generation(llama_client_slot &slot)
     {
         const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
@@ -1336,10 +1368,6 @@ struct llama_server_context
                 split_multiprompt_task(task_id, task);
             }
         } else {
-            // an empty prompt can make slot become buggy
-            if (task.data.contains("prompt") && task.data["prompt"].is_string() && task.data["prompt"].get<std::string>().empty()) {
-                task.data["prompt"] = " "; // add a space so that we have one token
-            }
             queue_tasks.post(task);
         }
     }
@@ -1487,11 +1515,15 @@ struct llama_server_context
                 slot->task_id = task.id;
                 slot->multitask_id = task.multitask_id;
 
-                if (!launch_slot_with_data(slot, task.data))
-                {
-                    // send error result
-                    send_error(task, "internal_error");
-                    break;
+                try {
+                    if (!launch_slot_with_data(slot, task.data))
+                    {
+                        // send error result
+                        send_error(task, "internal_error");
+                        break;
+                    }
+                } catch (const llama_error & err) {
+                    send_error(task, err);
                 }
             } break;
             case TASK_TYPE_CANCEL: { // release slot linked with the task id
@@ -3129,7 +3161,15 @@ int main(int argc, char **argv)
                 if (!validate_api_key(req, res)) {
                     return;
                 }
-                json data = json::parse(req.body);
+                json data;
+                try {
+                    data = json::parse(req.body);
+                } catch(const json::exception & json_err) {
+                    const auto err = llama_error("request.invalid_json", std::string("Invalid JSON: ") + json_err.what());
+                    const auto err_json = llama_server_context::error_to_json(err).dump();
+                    res.set_content(err_json, "text/plain; charset=utf-8");
+                    return;
+                }
                 const int task_id = llama.queue_tasks.get_new_id();
                 llama.queue_results.add_waiting_task_id(task_id);
                 llama.request_completion(task_id, data, false, false, -1);
@@ -3141,7 +3181,6 @@ int main(int argc, char **argv)
                     }
                     else
                     {
-                        res.status = 404;
                        res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
                    }
                    llama.queue_results.remove_waiting_task_id(task_id);
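
For reference, here is a minimal caller-side sketch (not part of the patch; the helper name `init_sampling_or_report`, the include paths, and the logging are illustrative assumptions) showing how code that previously checked `llama_sampling_init()` for a `nullptr` return could handle the `llama_error` exception introduced above:

```cpp
// Hypothetical caller of the patched llama_sampling_init(); assumes the
// common/ headers are on the include path.
#include <cstdio>

#include "sampling.h"   // llama_sampling_init, llama_sampling_params
#include "error.h"      // llama_error added by this patch

static llama_sampling_context * init_sampling_or_report(const llama_sampling_params & params) {
    try {
        // With this change, a bad or empty grammar throws instead of returning nullptr.
        return llama_sampling_init(params);
    } catch (const llama_error & err) {
        // err.id() is the machine-readable code (e.g. "grammar.invalid"),
        // err.description() is the human-readable message.
        fprintf(stderr, "sampling init failed [%s]: %s\n", err.id().c_str(), err.description().c_str());
        return nullptr;
    }
}
```

Keeping a stable `id` separate from the free-form `description` lets HTTP clients branch on the error type without parsing the message text, matching the `{"error": {"id": ..., "description": ...}}` shape documented in the server README hunk above.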