server: do not crash on grammar error

This commit is contained in:
ngxson 2024-03-09 15:27:42 +01:00
parent ba9c3e3192
commit de81c22abd

View file

@ -798,7 +798,7 @@ struct server_context {
return last_used; return last_used;
} }
bool launch_slot_with_data(server_slot & slot, json data) const { bool launch_slot_with_data(server_slot & slot, json data, std::string & error_message) const {
slot_params default_params; slot_params default_params;
llama_sampling_params default_sparams; llama_sampling_params default_sparams;
@ -841,12 +841,12 @@ struct server_context {
} }
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
// Might be better to reject the request with a 400 ?
LOG_WARNING("Max tokens to predict exceeds server configuration", { LOG_WARNING("Max tokens to predict exceeds server configuration", {
{"params.n_predict", slot.params.n_predict}, {"params.n_predict", slot.params.n_predict},
{"slot.n_predict", slot.n_predict}, {"slot.n_predict", slot.n_predict},
}); });
slot.params.n_predict = slot.n_predict; error_message = "Max tokens to predict exceeds server configuration";
return false;
} }
// infill // infill
@ -910,6 +910,7 @@ struct server_context {
if (logit_bias != data.end() && logit_bias->is_array()) { if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_n_vocab(model); const int n_vocab = llama_n_vocab(model);
for (const auto & el : *logit_bias) { for (const auto & el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect
if (el.is_array() && el.size() == 2) { if (el.is_array() && el.size() == 2) {
float bias; float bias;
if (el[1].is_number()) { if (el[1].is_number()) {
@ -969,6 +970,11 @@ struct server_context {
llama_sampling_free(slot.ctx_sampling); llama_sampling_free(slot.ctx_sampling);
} }
slot.ctx_sampling = llama_sampling_init(slot.sparams); slot.ctx_sampling = llama_sampling_init(slot.sparams);
if (slot.ctx_sampling == nullptr) {
// for now, the only error that may happen here is invalid grammar
error_message = "Failed to parse grammar";
return false;
}
llama_set_rng_seed(ctx, slot.params.seed); llama_set_rng_seed(ctx, slot.params.seed);
} }
@ -1456,9 +1462,10 @@ struct server_context {
slot->infill = task.infill; slot->infill = task.infill;
slot->embedding = task.embedding; slot->embedding = task.embedding;
if (!launch_slot_with_data(*slot, task.data)) { std::string err_launch_slot;
if (!launch_slot_with_data(*slot, task.data, err_launch_slot)) {
// send error result // send error result
send_error(task, "Cannot launch slot"); send_error(task, err_launch_slot.empty() ? "Cannot launch slot" : err_launch_slot);
break; break;
} }
} break; } break;