server : (embeddings) using same format for "input" and "content" (#10872)
* server : (embeddings) using same format for "input" and "content" * fix test case * handle empty input case * fix test
This commit is contained in:
parent
6b064c92b4
commit
46828872c3
3 changed files with 47 additions and 9 deletions
|
@ -3651,25 +3651,33 @@ int main(int argc, char ** argv) {
|
|||
const json body = json::parse(req.body);
|
||||
bool oaicompat = false;
|
||||
|
||||
// an input prompt can be a string or a list of tokens (integer)
|
||||
// for the shape of input/content, see tokenize_input_prompts()
|
||||
json prompt;
|
||||
if (body.count("input") != 0) {
|
||||
if (body.contains("input")) {
|
||||
oaicompat = true;
|
||||
prompt = body.at("input");
|
||||
} else if (body.count("content") != 0) {
|
||||
// with "content", we only support single prompt
|
||||
prompt = std::vector<std::string>{body.at("content")};
|
||||
} else if (body.contains("content")) {
|
||||
oaicompat = false;
|
||||
prompt = body.at("content");
|
||||
} else {
|
||||
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
|
||||
for (const auto & tokens : tokenized_prompts) {
|
||||
// this check is necessary for models that do not add BOS token to the input
|
||||
if (tokens.empty()) {
|
||||
res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// create and queue the task
|
||||
json responses = json::array();
|
||||
bool error = false;
|
||||
{
|
||||
std::vector<server_task> tasks;
|
||||
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
|
||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue