launch_slot_with_task

This commit is contained in:
ngxson 2024-03-09 21:13:30 +01:00
parent 02e49353c7
commit c240bd7026

View file

@ -798,9 +798,10 @@ struct server_context {
return last_used; return last_used;
} }
bool launch_slot_with_data(server_slot & slot, json data) { bool launch_slot_with_task(server_slot & slot, const server_task & task) {
slot_params default_params; slot_params default_params;
llama_sampling_params default_sparams; llama_sampling_params default_sparams;
auto & data = task.data;
if (data.count("__oaicompat") != 0) { if (data.count("__oaicompat") != 0) {
slot.oaicompat = true; slot.oaicompat = true;
@ -857,11 +858,15 @@ struct server_context {
{ {
const auto & prompt = data.find("prompt"); const auto & prompt = data.find("prompt");
if (prompt == data.end()) { if (prompt == data.end()) {
send_error(slot, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST); send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST);
return false; return false;
} else { } else {
slot.prompt = *prompt; slot.prompt = *prompt;
} }
if (slot.prompt.is_array() && slot.prompt.size() == 0) {
send_error(task, "\"prompt\" cannot be an empty array", ERROR_TYPE_INVALID_REQUEST);
return false;
}
} }
// penalize user-provided tokens // penalize user-provided tokens
@ -982,7 +987,7 @@ struct server_context {
slot.ctx_sampling = llama_sampling_init(slot.sparams); slot.ctx_sampling = llama_sampling_init(slot.sparams);
if (slot.ctx_sampling == nullptr) { if (slot.ctx_sampling == nullptr) {
// for now, the only error that may happen here is invalid grammar // for now, the only error that may happen here is invalid grammar
send_error(slot, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
return false; return false;
} }
llama_set_rng_seed(ctx, slot.params.seed); llama_set_rng_seed(ctx, slot.params.seed);
@ -1476,7 +1481,7 @@ struct server_context {
slot->infill = task.infill; slot->infill = task.infill;
slot->embedding = task.embedding; slot->embedding = task.embedding;
if (!launch_slot_with_data(*slot, task.data)) { if (!launch_slot_with_task(*slot, task)) {
LOG_ERROR("error while launching slot", task.data); LOG_ERROR("error while launching slot", task.data);
break; break;
} }