server: error handling

ZXED 2024-02-28 20:46:10 +03:00
parent 08c5ee87e4
commit bcb60f306f
No known key found for this signature in database (GPG key ID: 637FB44813DCFD66)
7 changed files with 102 additions and 18 deletions

Makefile

@@ -634,7 +634,7 @@ OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o
llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
$(CXX) $(CXXFLAGS) -c $< -o $@
COMMON_H_DEPS = common/common.h common/sampling.h common/log.h
COMMON_H_DEPS = common/common.h common/sampling.h common/log.h common/error.h
COMMON_DEPS = common.o sampling.o grammar-parser.o build-info.o
common.o: common/common.cpp $(COMMON_H_DEPS)

common/CMakeLists.txt

@@ -57,7 +57,8 @@ add_library(${TARGET} STATIC
grammar-parser.cpp
train.h
train.cpp
)
error.h
)
if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)

common/error.h (new file, 24 lines)

@@ -0,0 +1,24 @@
#pragma once

#include <cstdio>
#include <exception>
#include <string>

class llama_error : public std::exception
{
private:
    std::string _id;
    std::string _description;

public:
    llama_error(const std::string & id, const std::string & description)
    :
        _id(id),
        _description(description)
    {
        fprintf(stderr, "ERROR [%s]: %s\n", id.c_str(), description.c_str());
    }

    inline const std::string & id() const { return _id; }
    inline const std::string & description() const { return _description; }
};
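
The new header gives every layer a single exception type that carries a stable, machine-readable id next to a human-readable description, and it logs itself to stderr as soon as it is constructed. A minimal sketch of the intended throw/catch pattern, with illustrative helper functions that are not part of the commit:

```cpp
#include <cstdio>
#include <string>

#include "error.h"

// Illustrative only: reject bad input by throwing llama_error instead of
// printing an error and returning a sentinel value.
static void require_non_empty_grammar(const std::string & grammar) {
    if (grammar.empty()) {
        // the constructor already prints "ERROR [grammar.empty]: ..." to stderr
        throw llama_error("grammar.empty", "the grammar string must not be empty");
    }
}

static bool grammar_ok(const std::string & grammar) {
    try {
        require_non_empty_grammar(grammar);
        return true;
    } catch (const llama_error & err) {
        // id() stays machine-readable so callers can branch on it;
        // description() is meant for humans
        fprintf(stderr, "rejected [%s]\n", err.id().c_str());
        return false;
    }
}
```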

common/grammar-parser.cpp

@@ -1,4 +1,6 @@
#include "grammar-parser.h"
#include "error.h"
#include <cstdint>
#include <cwchar>
#include <string>
@@ -280,8 +282,7 @@ namespace grammar_parser {
}
return state;
} catch (const std::exception & err) {
fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
return parse_state();
throw llama_error("grammar.invalid", std::string(__func__) + ": error parsing grammar: " + err.what());
}
}
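
Since `grammar_parser::parse` now throws instead of printing and returning an empty `parse_state`, every direct caller needs a `try`/`catch`. A hedged sketch of how such a call site adapts; the `grammar_is_valid` helper is illustrative, not from the commit:

```cpp
#include <cstdio>

#include "grammar-parser.h"
#include "error.h"

// Illustrative: validate a grammar string without keeping the parsed rules.
static bool grammar_is_valid(const char * grammar_src) {
    try {
        grammar_parser::parse_state state = grammar_parser::parse(grammar_src);
        // an empty rule set is still possible, e.g. for an empty input string
        return !state.rules.empty();
    } catch (const llama_error & err) {
        fprintf(stderr, "grammar rejected [%s]: %s\n", err.id().c_str(), err.description().c_str());
        return false;
    }
}
```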

common/sampling.cpp

@@ -1,4 +1,5 @@
#include "sampling.h"
#include "error.h"
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
@@ -8,13 +9,17 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
// if there is a grammar, parse it
if (!params.grammar.empty()) {
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
try {
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
} catch (const llama_error & err) {
delete result;
throw err;
}
// will be empty (default) if there are parse errors
if (result->parsed_grammar.rules.empty()) {
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
delete result;
return nullptr;
throw llama_error("grammar.empty", std::string(__func__) + ": empty grammar");
}
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
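
`llama_sampling_init` therefore no longer returns `nullptr` on a bad grammar; it frees the partially built context and lets the `llama_error` escape. A hedged sketch of how a caller that used to check for `nullptr` can keep its old contract; the wrapper is illustrative:

```cpp
#include <cstdio>

#include "sampling.h"
#include "error.h"

// Illustrative wrapper: restores the old nullptr-on-failure behaviour for
// callers that are not ready to handle exceptions themselves.
static llama_sampling_context * sampling_init_or_null(const llama_sampling_params & params) {
    try {
        return llama_sampling_init(params);
    } catch (const llama_error & err) {
        fprintf(stderr, "sampling init failed [%s]: %s\n",
                err.id().c_str(), err.description().c_str());
        return nullptr;
    }
}
```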

examples/server/README.md

@@ -275,6 +275,20 @@ Notice that each `probs` is an array of length `n_probs`.
- `tokens_evaluated`: Number of tokens evaluated in total from the prompt
- `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens_predicted`) exceeded the context size (`n_ctx`)
In case of an error the error details will be returned as follows:
```json
{
  "error": {
    "description": "parse: error parsing grammar: expecting name at (",
    "id": "grammar.invalid"
  }
}
```
where:
- `description` - human-readable description
- `id` - unique ID for this error type
- **POST** `/tokenize`: Tokenize a given text.
*Options:*
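
Because the error body is structured, a client can branch on `id` instead of scraping the description text. A sketch of such handling in C++, assuming the client uses nlohmann::json (the same JSON library the server bundles; the include path below is the upstream one and may differ from the vendored `json.hpp`):

```cpp
#include <iostream>
#include <string>

#include <nlohmann/json.hpp> // assumed upstream include path for nlohmann::json

using json = nlohmann::json;

// Illustrative client-side handling of the error body documented above.
// Returns true if the body was an error payload.
static bool handle_error_body(const std::string & body) {
    const json data = json::parse(body, /* cb = */ nullptr, /* allow_exceptions = */ false);
    if (!data.is_object() || !data.contains("error")) {
        return false;
    }
    const std::string id          = data["error"]["id"];
    const std::string description = data["error"]["description"];
    if (id == "grammar.invalid" || id == "grammar.empty") {
        std::cerr << "fix the request grammar: " << description << std::endl;
    } else {
        std::cerr << "server error [" << id << "]: " << description << std::endl;
    }
    return true;
}
```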

examples/server/server.cpp

@@ -23,6 +23,7 @@
#include "index.js.hpp"
#include "completion.js.hpp"
#include "json-schema-to-grammar.mjs.hpp"
#include "error.h"
#include <cstddef>
#include <thread>
@@ -674,6 +675,15 @@ struct llama_server_context
slot->prompt = "";
}
if (
(slot->prompt.is_string() && slot->prompt.get<std::string>().empty())
||
(slot->prompt.is_array() && slot->prompt.empty())
)
{
throw llama_error("prompt.empty", "The prompt must not be empty");
}
slot->sparams.penalty_prompt_tokens.clear();
slot->sparams.use_penalty_prompt_tokens = false;
const auto &penalty_prompt = data.find("penalty_prompt");
@@ -1132,6 +1142,28 @@ struct llama_server_context
queue_results.send(res);
}
static json error_to_json(const llama_error& error)
{
return {
{ "error", {
{ "id", error.id() },
{ "description", error.description() }
} }
};
}
void send_error(task_server& task, const llama_error& error)
{
LOG_TEE("task %i - error: %s - %s\n", task.id, error.id().c_str(), error.description().c_str());
task_result res;
res.id = task.id;
res.multitask_id = task.multitask_id;
res.stop = false;
res.error = true;
res.result_json = { { "content", error_to_json(error).dump() } };
queue_results.send(res);
}
json get_formated_generation(llama_client_slot &slot)
{
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
@@ -1336,10 +1368,6 @@ struct llama_server_context
split_multiprompt_task(task_id, task);
}
} else {
// an empty prompt can make slot become buggy
if (task.data.contains("prompt") && task.data["prompt"].is_string() && task.data["prompt"].get<std::string>().empty()) {
task.data["prompt"] = " "; // add a space so that we have one token
}
queue_tasks.post(task);
}
}
@@ -1487,11 +1515,15 @@ struct llama_server_context
slot->task_id = task.id;
slot->multitask_id = task.multitask_id;
if (!launch_slot_with_data(slot, task.data))
{
// send error result
send_error(task, "internal_error");
break;
try {
if (!launch_slot_with_data(slot, task.data))
{
// send error result
send_error(task, "internal_error");
break;
}
} catch (const llama_error & err) {
send_error(task, err);
}
} break;
case TASK_TYPE_CANCEL: { // release slot linked with the task id
@@ -3129,7 +3161,15 @@ int main(int argc, char **argv)
if (!validate_api_key(req, res)) {
return;
}
json data = json::parse(req.body);
json data;
try {
data = json::parse(req.body);
} catch(const json::exception & json_err) {
const auto err = llama_error("request.invalid_json", std::string("Invalid JSON: ") + json_err.what());
const auto err_json = llama_server_context::error_to_json(err).dump();
res.set_content(err_json, "text/plain; charset=utf-8");
return;
}
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, false, -1);
@@ -3141,7 +3181,6 @@ int main(int argc, char **argv)
}
else
{
res.status = 404;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
}
llama.queue_results.remove_waiting_task_id(task_id);
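
The invalid-JSON guard reuses the same `error_to_json` helper, so a malformed request body comes back in the `{"error": {...}}` shape documented in the README, just with the `request.invalid_json` id. A standalone sketch of that payload, duplicating the helper locally for illustration; the description string here is an example, the real text comes from `json_err.what()`:

```cpp
#include <iostream>
#include <string>

#include <nlohmann/json.hpp> // assumed upstream include path for nlohmann::json

using json = nlohmann::json;

// Mirrors llama_server_context::error_to_json, duplicated here for illustration.
static json error_to_json(const std::string & id, const std::string & description) {
    return {
        { "error", {
            { "id", id },
            { "description", description }
        } }
    };
}

int main() {
    // example description; the real text comes from the JSON parser's exception
    const json err = error_to_json("request.invalid_json", "Invalid JSON: syntax error while parsing value");
    std::cout << err.dump(2) << std::endl; // sent back as text/plain; charset=utf-8
    return 0;
}
```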