server: add OpenAI compatible response format for /completions with backward compatibility
Parent: 3b4f2e33e2
Commit: 938dbd4d6e
9 changed files with 470 additions and 186 deletions
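Before the file-by-file changes, a minimal sketch of how a client exercises the two formats after this commit. This is illustrative only: the base URL and the use of the `requests` library are assumptions, not part of the diff, but the request and response fields mirror the updated tests further down.

# Sketch only: assumes a llama-server instance at this base URL.
import requests

base = "http://localhost:8080"

# /completions now defaults to the OpenAI legacy completion format:
res = requests.post(f"{base}/completions", json={
    "prompt": "I believe the meaning of life is",
    "max_tokens": 8,
    "logprobs": 3,   # top-3 candidates per token, mapped onto n_probs internally
})
body = res.json()
print(body["object"])                 # "text_completion"
print(body["choices"][0]["text"])     # generated text
print(body["usage"]["total_tokens"])  # prompt_tokens + completion_tokens

# The original llama.cpp schema remains available by opting out:
res = requests.post(f"{base}/completion", json={
    "prompt": "I believe the meaning of life is",
    "n_predict": 8,
    "oai_compat": False,              # keep the old {"content": ...} response
})
print(res.json()["content"])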
@@ -924,7 +924,7 @@ struct server_context {
     slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
     slot.params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
     slot.params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
-    slot.params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
+    slot.params.sampling.n_probs = json_value(data, "n_probs", json_value(data, "logprobs", defaults.sampling.n_probs));
     slot.params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);

     slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
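The hunk above makes the OpenAI-style "logprobs" request field act as a fallback for the native "n_probs" sampling parameter. A small illustration of the two equivalent request bodies (values arbitrary, sketch only):

# Both requests ask for the top 3 candidate tokens per position.
native_request = {"prompt": "Hello", "n_predict": 8, "n_probs": 3, "oai_compat": False}
openai_request = {"prompt": "Hello", "max_tokens": 8, "logprobs": 3}
# If both keys are sent, "n_probs" takes precedence over "logprobs".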
@@ -1340,7 +1340,8 @@ struct server_context {
             }
             slot.n_sent_token_probs = probs_stop_pos;

-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
+            // TODO: bool to determine if we are using the new format (/chat/completions) or the legacy format (/completions) for logprobs
+            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output, true);
         }

         if (slot.oaicompat) {
@@ -1379,7 +1380,7 @@ struct server_context {
             {"timings", slot.get_formated_timings()},
             {"index", slot.index},
         };

         if (slot.params.sampling.n_probs > 0) {
             std::vector<completion_token_output> probs;
             if (!slot.params.stream && slot.stopped_word) {
@@ -1395,7 +1396,8 @@ struct server_context {
                 slot.generated_token_probs.end());
             }

-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
+            // TODO: bool to determine if we are using the new format (/chat/completions) or the legacy format (/completions) for logprobs
+            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs, true);
         }

         if (slot.oaicompat) {
@@ -2901,31 +2903,63 @@ int main(int argc, char ** argv) {
         res_ok(res, {{ "success", true }});
     };

-    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
+    const auto handle_completions_generic = [&ctx_server, &params, &res_error, &res_ok, verbose](server_task_inf_type inf_type, json & data, httplib::Response & res, bool is_chat = false) {
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }

+        // Parse data for /chat/completions format if needed
+        if (is_chat) {
+            data = oaicompat_completion_params_parse(ctx_server.model, data, params.chat_template);
+        }

         std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);

         bool stream = json_value(data, "stream", false);
+        bool oai_compat = json_value(data, "oai_compat", true);
         const auto task_ids = server_task::get_list_id(tasks);
+        const auto completion_id = gen_chatcmplid();

         if (!stream) {
             ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
-                if (results.size() == 1) {
-                    // single result
-                    res_ok(res, results[0].data);
-                } else {
-                    // multiple results (multitask)
-                    json arr = json::array();
-                    for (const auto & res : results) {
-                        arr.push_back(res.data);
+                if (inf_type == SERVER_TASK_INF_TYPE_COMPLETION && (oai_compat || is_chat)) {
+                    if (is_chat) {
+                        // multitask is never supported in chat completion, there is only one result
+                        json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id,
+                            /*.streaming =*/ false, verbose, /*.legacy_format =*/ !is_chat);
+                        res_ok(res, result_oai);
+                    } else {
+                        if (results.size() == 1) {
+                            // single result
+                            json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id,
+                                /*.streaming =*/ false, verbose, /*.legacy_format =*/ true);
+                            res_ok(res, result_oai);
+                        } else {
+                            // multiple results (multitask)
+                            json arr = json::array();
+                            for (const auto & result : results) {
+                                arr.push_back(format_final_response_oaicompat(data, result.data, completion_id,
+                                    /*.streaming =*/ false, verbose, /*.legacy_format =*/ true));
+                            }
+                            res_ok(res, arr);
+                        }
+                    }
+                }
+                else {
+                    if (results.size() == 1) {
+                        // single result
+                        res_ok(res, results[0].data);
+                    } else {
+                        // multiple results (multitask)
+                        json arr = json::array();
+                        for (const auto & res : results) {
+                            arr.push_back(res.data);
+                        }
+                        res_ok(res, arr);
                     }
-                    res_ok(res, arr);
                 }
             }, [&](const json & error_data) {
                 res_error(res, error_data);
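In the non-streaming path above, /completions results are routed through format_final_response_oaicompat unless the client sends "oai_compat": false; a multitask request (a list of prompts) yields a top-level JSON array with one text_completion object per prompt. A rough client-side sketch (the server address and the `requests` client are assumptions):

import requests

base = "http://localhost:8080"  # assumed server address
res = requests.post(f"{base}/completions", json={
    "prompt": ["Hello", "Goodbye"],   # two prompts -> multitask
    "max_tokens": 8,
})
for completion in res.json():         # one OpenAI-style object per prompt
    assert completion["object"] == "text_completion"
    print(completion["choices"][0]["text"])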
@@ -2933,14 +2967,35 @@ int main(int argc, char ** argv) {
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, is_chat, inf_type, oai_compat](size_t, httplib::DataSink & sink) {
                 ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
-                    return server_sent_event(sink, "data", result.data);
+                    if (inf_type == SERVER_TASK_INF_TYPE_COMPLETION && (oai_compat || is_chat)) {
+                        std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id, !is_chat);
+                        for (auto & event_data : result_array) {
+                            if (event_data.empty()) {
+                                continue; // skip the stop token
+                            }
+                            if (!server_sent_event(sink, "data", event_data)) {
+                                return false; // connection is closed
+                            }
+                        }
+                        return true; // ok
+                    }
+                    return server_sent_event(sink, "data", result.data);
                 }, [&](const json & error_data) {
                     server_sent_event(sink, "error", error_data);
                 });
+
+                if (is_chat) {
+                    static const std::string ev_done = "data: [DONE]\n\n";
+                    sink.write(ev_done.data(), ev_done.size());
+                }
                 sink.done();
-                return false;
+                return true;
             };

             auto on_complete = [task_ids, &ctx_server] (bool) {
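With the streaming change above, /completions streams OpenAI-style text_completion chunks as server-sent events; the "data: [DONE]" terminator is only written on the chat path (is_chat). A rough consumer sketch for the legacy endpoint, with the base URL and `requests` usage assumed for illustration:

import json
import requests

base = "http://localhost:8080"  # assumed server address
with requests.post(f"{base}/v1/completions", json={
    "prompt": "I believe the meaning of life is",
    "max_tokens": 8,
    "stream": True,
}, stream=True) as res:
    for line in res.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        chunk = json.loads(line[len(b"data: "):])
        choice = chunk["choices"][0]
        print(choice["text"], end="", flush=True)
        if choice["finish_reason"] is not None:
            break  # final chunk also carries the "usage" block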
@@ -2953,7 +3008,12 @@ int main(int argc, char ** argv) {
     const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
-        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, false);
+    };
+
+    const auto handle_chat_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+        json data = json::parse(req.body);
+        return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true);
     };

     const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -3006,63 +3066,6 @@ int main(int argc, char ** argv) {
         return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
     };

-    // TODO: maybe merge this function with "handle_completions_generic"
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params_base.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
-            return;
-        }
-
-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
-
-        std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
-        ctx_server.queue_results.add_waiting_tasks(tasks);
-        ctx_server.queue_tasks.post(tasks);
-
-        bool stream = json_value(data, "stream", false);
-        const auto task_ids = server_task::get_list_id(tasks);
-        const auto completion_id = gen_chatcmplid();
-
-        if (!stream) {
-            ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
-                // multitask is never support in chat completion, there is only one result
-                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
-                res_ok(res, result_oai);
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-            });
-
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-        } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
-                ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
-                    std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
-                    for (auto & event_data : result_array) {
-                        if (event_data.empty()) {
-                            continue; // skip the stop token
-                        }
-                        if (!server_sent_event(sink, "data", event_data)) {
-                            return false; // connection is closed
-                        }
-                    }
-                    return true; // ok
-                }, [&](const json & error_data) {
-                    server_sent_event(sink, "error", error_data);
-                });
-                static const std::string ev_done = "data: [DONE]\n\n";
-                sink.write(ev_done.data(), ev_done.size());
-                sink.done();
-                return true;
-            };
-
-            auto on_complete = [task_ids, &ctx_server] (bool) {
-                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-            };
-
-            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-        }
-    };
-
     const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
         json models = {
             {"object", "list"},
@@ -40,9 +40,19 @@ def test_load_split_model():
     server.model_alias = "tinyllama-split"
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "n_predict": 16,
+        "max_tokens": 16,
         "prompt": "Hello",
         "temperature": 0.0,
     })
     assert res.status_code == 200
-    assert match_regex("(little|girl)+", res.body["content"])
+    # Verify response structure
+    assert "id" in res.body
+    assert "object" in res.body
+    assert "created" in res.body
+    assert "model" in res.body
+    assert "choices" in res.body
+    assert isinstance(res.body["choices"], list)
+    assert len(res.body["choices"]) > 0
+    assert "text" in res.body["choices"][0]
+    # Verify the actual content
+    assert match_regex("(little|girl)+", res.body["choices"][0]["text"])
@@ -18,13 +18,13 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
     global server
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "n_predict": n_predict,
+        "max_tokens": n_predict,
         "prompt": prompt,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert res.body["timings"]["prompt_n"] == n_prompt
     assert res.body["timings"]["predicted_n"] == n_predicted
-    assert res.body["truncated"] == truncated
     assert match_regex(re_content, res.body["content"])

@@ -36,16 +36,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
     global server
     server.start()
     res = server.make_stream_request("POST", "/completion", data={
-        "n_predict": n_predict,
+        "max_tokens": n_predict,
         "prompt": prompt,
         "stream": True,
+        "oai_compat": False,
     })
     content = ""
     for data in res:
         if data["stop"]:
             assert data["timings"]["prompt_n"] == n_prompt
             assert data["timings"]["predicted_n"] == n_predicted
-            assert data["truncated"] == truncated
             assert match_regex(re_content, content)
         else:
             content += data["content"]

@@ -63,6 +63,7 @@ def test_consistent_result_same_seed(n_slots: int):
         "seed": 42,
         "temperature": 1.0,
         "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+        "oai_compat": False,
     })
     if last_res is not None:
         assert res.body["content"] == last_res.body["content"]

@@ -81,6 +82,7 @@ def test_different_result_different_seed(n_slots: int):
         "seed": seed,
         "temperature": 1.0,
         "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+        "oai_compat": False,
     })
     if last_res is not None:
         assert res.body["content"] != last_res.body["content"]

@@ -100,6 +102,7 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
         "seed": 42,
         "temperature": temperature,
         "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+        "oai_compat": False,
     })
     if last_res is not None:
         assert res.body["content"] == last_res.body["content"]

@@ -115,12 +118,14 @@ def test_cache_vs_nocache_prompt():
         "seed": 42,
         "temperature": 1.0,
         "cache_prompt": True,
+        "oai_compat": False,
     })
     res_no_cache = server.make_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
         "seed": 42,
         "temperature": 1.0,
         "cache_prompt": False,
+        "oai_compat": False,
     })
     assert res_cache.body["content"] == res_no_cache.body["content"]

@@ -140,6 +145,7 @@ def test_completion_with_tokens_input():
     # single completion
     res = server.make_request("POST", "/completion", data={
         "prompt": tokens,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert type(res.body["content"]) == str

@@ -147,6 +153,7 @@ def test_completion_with_tokens_input():
     # batch completion
     res = server.make_request("POST", "/completion", data={
         "prompt": [tokens, tokens],
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert type(res.body) == list

@@ -156,6 +163,7 @@ def test_completion_with_tokens_input():
     # mixed string and tokens
     res = server.make_request("POST", "/completion", data={
         "prompt": [tokens, prompt_str],
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert type(res.body) == list

@@ -165,6 +173,7 @@ def test_completion_with_tokens_input():
     # mixed string and tokens in one sequence
     res = server.make_request("POST", "/completion", data={
         "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert type(res.body["content"]) == str

@@ -208,6 +217,7 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
             "prompt": prompt,
             "seed": 42,
             "temperature": 1.0,
+            "oai_compat": False,
         })))
     tasks.append((check_slots_status, ()))
     results = parallel_function_calls(tasks)
@@ -221,3 +231,122 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
         assert len(res.body["content"]) > 10
         # FIXME: the result is not deterministic when using other slot than slot 0
         # assert match_regex(re_content, res.body["content"])
+
+# OpenAI legacy completion endpoint tests
+@pytest.mark.parametrize("prompt,n_predict,expected_text,n_prompt,n_predicted", [
+    ("I believe the meaning of life is", 8, "going to bed", 18, 8),
+    ("Write a joke about", 16, "Why did the AI", 14, 16),
+])
+def test_completion_openai(prompt: str, n_predict: int, expected_text: str, n_prompt: int, n_predicted: int):
+    global server
+    server.start()
+
+    # Test non-streaming response
+    res = server.make_request("POST", "/completions", data={
+        "model": "local-model",
+        "prompt": prompt,
+        "max_tokens": n_predict,
+        "logprobs": 3,
+        "echo": True
+    })
+
+    assert res.status_code == 200
+    assert res.body["object"] == "text_completion"
+    assert isinstance(res.body["id"], str)
+    assert isinstance(res.body["created"], int)
+    assert res.body["model"] == "local-model"
+
+    # Check choices array
+    assert len(res.body["choices"]) == 1
+    choice = res.body["choices"][0]
+    assert choice["index"] == 0
+    assert isinstance(choice["text"], str)
+    assert choice["finish_reason"] in ["stop", "length"]
+
+    # Check logprobs
+    assert choice["logprobs"] is not None
+    assert "tokens" in choice["logprobs"]
+    assert "token_logprobs" in choice["logprobs"]
+    assert "top_logprobs" in choice["logprobs"]
+    assert len(choice["logprobs"]["top_logprobs"]) == len(choice["logprobs"]["tokens"])
+
+    # Check usage statistics
+    assert "usage" in res.body
+    assert res.body["usage"]["prompt_tokens"] == n_prompt
+    assert res.body["usage"]["completion_tokens"] == n_predicted
+    assert res.body["usage"]["total_tokens"] == n_prompt + n_predicted
+
+
+@pytest.mark.parametrize("prompt,n_predict,expected_text,n_prompt,n_predicted", [
+    ("I believe the meaning of life is", 8, "going to bed", 18, 8),
+    ("Write a joke about", 16, "Why did the AI", 14, 16),
+])
+def test_completion_openai_stream(prompt: str, n_predict: int, expected_text: str, n_prompt: int, n_predicted: int):
+    global server
+    server.start()
+
+    res = server.make_stream_request("POST", "/v1/completions", data={
+        "prompt": prompt,
+        "max_tokens": n_predict,
+        "stream": True,
+    })
+
+    collected_text = ""
+    is_first_chunk = True
+    for data in res:
+        assert "id" in data
+        assert data["object"] == "text_completion"
+        assert isinstance(data["created"], int)
+
+        assert len(data["choices"]) == 1
+        choice = data["choices"][0]
+        assert choice["index"] == 0
+        assert isinstance(choice["text"], str)
+        collected_text += choice["text"]
+
+        if is_first_chunk:
+            # First chunk should have model info
+            is_first_chunk = False
+
+        if choice["finish_reason"] is not None:
+            # This is the last chunk
+            assert choice["finish_reason"] in ["stop", "length"]
+            assert "usage" in data
+            assert data["usage"]["prompt_tokens"] == n_prompt
+            assert data["usage"]["completion_tokens"] == n_predicted
+            assert data["usage"]["total_tokens"] == n_prompt + n_predicted
+
+
+@pytest.mark.parametrize("prompt,n_predict,expected_text,n_prompt,n_predicted", [
+    ("I believe the meaning of life is", 8, "going to bed", 18, 8),
+    ("Write a joke about", 16, "Why did the AI", 14, 16),
+])
+def test_completion_openai_no_logprobs(prompt: str, n_predict: int, expected_text: str, n_prompt: int, n_predicted: int):
+    global server
+    server.start()
+
+    # Test non-streaming response
+    res = server.make_request("POST", "/completions", data={
+        "prompt": prompt,
+        "max_tokens": n_predict,
+        "echo": True
+    })
+
+    assert res.status_code == 200
+    assert res.body["object"] == "text_completion"
+    assert isinstance(res.body["id"], str)
+    assert isinstance(res.body["created"], int)
+
+    # Check choices array
+    assert len(res.body["choices"]) == 1
+    choice = res.body["choices"][0]
+    assert choice["index"] == 0
+    assert isinstance(choice["text"], str)
+    assert choice["finish_reason"] in ["stop", "length"]
+
+    # Verify logprobs is None when not requested
+    assert choice["logprobs"] is None
+
+    # Check usage statistics
+    assert "usage" in res.body
+    assert res.body["usage"]["prompt_tokens"] == n_prompt
+    assert res.body["usage"]["completion_tokens"] == n_predicted
+    assert res.body["usage"]["total_tokens"] == n_prompt + n_predicted
@@ -29,6 +29,7 @@ def test_ctx_shift_enabled():
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
         "prompt": LONG_TEXT,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert res.body["timings"]["prompt_n"] == 109

@@ -48,6 +49,7 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr
     res = server.make_request("POST", "/completion", data={
         "n_predict": n_predict,
         "prompt": "Hi how are you",
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert res.body["timings"]["predicted_n"] == n_token_output

@@ -61,6 +63,7 @@ def test_ctx_shift_disabled_long_prompt():
     res = server.make_request("POST", "/completion", data={
         "n_predict": 64,
         "prompt": LONG_TEXT,
+        "oai_compat": False,
     })
     assert res.status_code != 200
     assert "error" in res.body
@@ -36,6 +36,7 @@ def test_lora(scale: float, re_content: str):
     assert res_lora_control.status_code == 200
     res = server.make_request("POST", "/completion", data={
         "prompt": "Look in thy glass",
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert match_regex(re_content, res.body["content"])
@@ -41,6 +41,7 @@ def test_correct_api_key():
     server.start()
     res = server.make_request("POST", "/completions", data={
         "prompt": "I believe the meaning of life is",
+        "oai_compat": False,
     }, headers={
         "Authorization": f"Bearer {TEST_API_KEY}",
     })
@@ -20,6 +20,7 @@ def test_slot_save_restore():
         "prompt": "What is the capital of France?",
         "id_slot": 1,
         "cache_prompt": True,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert match_regex("(Whiskers|Flana)+", res.body["content"])

@@ -37,6 +38,7 @@ def test_slot_save_restore():
         "prompt": "What is the capital of Germany?",
         "id_slot": 1,
         "cache_prompt": True,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert match_regex("(Jack|said)+", res.body["content"])

@@ -54,6 +56,7 @@ def test_slot_save_restore():
         "prompt": "What is the capital of Germany?",
         "id_slot": 0,
         "cache_prompt": True,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert match_regex("(Jack|said)+", res.body["content"])

@@ -64,6 +67,7 @@ def test_slot_save_restore():
         "prompt": "What is the capital of Germany?",
         "id_slot": 1,
         "cache_prompt": True,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert match_regex("(Jack|said)+", res.body["content"])

@@ -78,6 +82,7 @@ def test_slot_erase():
         "prompt": "What is the capital of France?",
         "id_slot": 1,
         "cache_prompt": True,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     assert match_regex("(Whiskers|Flana)+", res.body["content"])

@@ -94,5 +99,5 @@ def test_slot_erase():
         "cache_prompt": True,
     })
     assert res.status_code == 200
-    assert match_regex("(Whiskers|Flana)+", res.body["content"])
+    assert match_regex("(Whiskers|Flana)+", res.body["choices"][0]["text"])
     assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
@@ -37,6 +37,7 @@ def test_with_and_without_draft():
         "prompt": "I believe the meaning of life is",
         "temperature": 0.0,
         "top_k": 1,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     content_no_draft = res.body["content"]

@@ -49,6 +50,7 @@ def test_with_and_without_draft():
         "prompt": "I believe the meaning of life is",
         "temperature": 0.0,
         "top_k": 1,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     content_draft = res.body["content"]

@@ -75,6 +77,7 @@ def test_different_draft_min_draft_max():
         "prompt": "I believe the meaning of life is",
         "temperature": 0.0,
         "top_k": 1,
+        "oai_compat": False,
     })
     assert res.status_code == 200
     if last_content is not None:

@@ -96,6 +99,7 @@ def test_multi_requests_parallel(n_slots: int, n_requests: int):
             "prompt": "I believe the meaning of life is",
             "temperature": 0.0,
             "top_k": 1,
+            "oai_compat": False,
         })))
     results = parallel_function_calls(tasks)
     for res in results:
@@ -498,28 +498,105 @@ struct completion_token_output {
 };

 // convert a vector of completion_token_output to json
-static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
-    json out = json::array();
-    for (const auto & prob : probs) {
-        json probs_for_token = json::array();
-        for (const auto & p : prob.probs) {
-            const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
-            probs_for_token.push_back(json {
-                {"tok_str", tok_str},
-                {"prob", p.prob},
-            });
+static json probs_vector_to_json(llama_context * ctx, const std::vector<completion_token_output> & probs, bool legacy_format = true) {
+    if (legacy_format) {
+        // Legacy format (text_completion endpoint)
+        json logprobs;
+        std::vector<std::string> tokens;
+        std::vector<json> token_logprobs; // Changed to json to allow null values
+        std::vector<json> top_logprobs; // Changed to allow null values
+        std::vector<int> text_offset;
+
+        int current_offset = 0;
+
+        for (const auto & prob : probs) {
+            std::string token_str = tokens_to_output_formatted_string(ctx, prob.tok);
+            tokens.push_back(token_str);
+            text_offset.push_back(current_offset);
+
+            // Handle token logprobs
+            if (!prob.probs.empty() && prob.probs[0].prob > 0) {
+                token_logprobs.push_back(std::log(prob.probs[0].prob));
+            } else {
+                token_logprobs.push_back(nullptr);
+            }
+
+            // Handle top logprobs
+            json token_top_logprobs = json::object();
+            for (const auto & p : prob.probs) {
+                if (p.prob > 0) {
+                    token_top_logprobs[tokens_to_output_formatted_string(ctx, p.tok)] = std::log(p.prob);
+                }
+            }
+            top_logprobs.push_back(token_top_logprobs.empty() ? nullptr : token_top_logprobs);
+
+            current_offset += token_str.length();
         }
-        const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
-        out.push_back(json {
-            {"content", tok_str},
-            {"probs", probs_for_token},
-        });
+
+        logprobs = {
+            {"tokens", tokens},
+            {"token_logprobs", token_logprobs},
+            {"top_logprobs", top_logprobs},
+            {"text_offset", text_offset}
+        };
+
+        return logprobs;
+    } else {
+        // New format (GPT-4 style)
+        json logprobs;
+        std::vector<json> content;
+
+        for (const auto & prob : probs) {
+            std::string token_str = tokens_to_output_formatted_string(ctx, prob.tok);
+
+            // Create top_logprobs array for this token
+            json token_top_logprobs = json::array();
+            for (const auto & p : prob.probs) {
+                if (p.prob > 0) {
+                    // Get UTF-8 bytes for the token
+                    std::vector<int> bytes;
+                    std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
+                    for (unsigned char c : tok_str) {
+                        bytes.push_back(static_cast<int>(c));
+                    }
+
+                    json entry = {
+                        {"token", tok_str},
+                        {"logprob", std::log(p.prob)},
+                        {"bytes", bytes.empty() ? json(nullptr) : json(bytes)}
+                    };
+                    token_top_logprobs.push_back(entry);
+                }
+            }
+
+            // Get main token logprob
+            float main_logprob = (!prob.probs.empty() && prob.probs[0].prob > 0)
+                ? std::log(prob.probs[0].prob)
+                : -9999.0f;
+
+            // Get UTF-8 bytes for the main token
+            std::vector<int> main_bytes;
+            for (unsigned char c : token_str) {
+                main_bytes.push_back(static_cast<int>(c));
+            }
+
+            // Add token info to content array
+            json token_info = {
+                {"token", token_str},
+                {"logprob", main_logprob},
+                {"bytes", main_bytes.empty() ? json(nullptr) : json(main_bytes)},
+                {"top_logprobs", token_top_logprobs}
+            };
+            content.push_back(token_info);
+        }
+
+        logprobs = {
+            {"content", content},
+            {"refusal", nullptr} // Add refusal field as null since we don't implement content filtering
+        };
+
+        return logprobs;
     }
-    return out;
 }

 static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
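The rewritten probs_vector_to_json above emits one of two logprobs layouts. A sketch of both shapes as a client receives them; token strings and numbers are made up:

# legacy_format == true: OpenAI text_completion style, attached per choice.
legacy_logprobs = {
    "tokens":         ["Hello", ","],
    "token_logprobs": [-0.12, -1.05],   # null when no probability was recorded
    "top_logprobs":   [{"Hello": -0.12, "Hi": -2.3}, {",": -1.05}],
    "text_offset":    [0, 5],
}

# legacy_format == false: chat-completions style "content" array.
chat_logprobs = {
    "content": [
        {
            "token": "Hello",
            "logprob": -0.12,
            "bytes": [72, 101, 108, 108, 111],
            "top_logprobs": [
                {"token": "Hello", "logprob": -0.12, "bytes": [72, 101, 108, 108, 111]},
            ],
        },
    ],
    "refusal": None,
}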
@@ -540,7 +617,8 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
 static json oaicompat_completion_params_parse(
     const struct llama_model * model,
     const json & body, /* openai api json semantics */
-    const std::string & chat_template) {
+    const std::string & chat_template
+) {
     json llama_params;

     llama_params["__oaicompat"] = true;
@@ -604,43 +682,71 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }

-static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
+static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false, bool legacy_format = false) {
     bool stopped_word = result.count("stopped_word") != 0;
     bool stopped_eos = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
     int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
     std::string content = json_value(result, "content", std::string(""));
+    bool truncated = json_value(result, "truncated", false);

     std::string finish_reason = "length";
     if (stopped_word || stopped_eos) {
         finish_reason = "stop";
     }

-    json choices =
-        streaming ? json::array({json{{"finish_reason", finish_reason},
-                                      {"index", 0},
-                                      {"delta", json::object()}}})
-                  : json::array({json{{"finish_reason", finish_reason},
-                                      {"index", 0},
-                                      {"message", json{{"content", content},
-                                                       {"role", "assistant"}}}}});
+    json choices;
+    // Use the pre-formatted logprobs directly
+    json logprobs = result.contains("completion_probabilities") ?
+        result["completion_probabilities"] : nullptr;
+    if (legacy_format) {
+        choices = json::array({json{
+            {"finish_reason", finish_reason},
+            {"index", 0},
+            {"logprobs", logprobs},
+            {"text", content}
+        }});
+    } else {
+        // Format for chat completions endpoint
+        choices = streaming ?
+            json::array({json{
+                {"finish_reason", finish_reason},
+                {"index", 0},
+                {"delta", json::object()}
+            }}) :
+            json::array({json{
+                {"finish_reason", finish_reason},
+                {"index", 0},
+                {"message", json{
+                    {"content", content},
+                    {"role", "assistant"}
+                }}
+            }});
+    }

     std::time_t t = std::time(0);

     json res = json {
         {"choices", choices},
         {"created", t},
-        {"model",
-            json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
-        {"object", streaming ? "chat.completion.chunk" : "chat.completion"},
+        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"object", legacy_format ? "text_completion" :
+            (streaming ? "chat.completion.chunk" : "chat.completion")},
         {"usage", json {
             {"completion_tokens", num_tokens_predicted},
             {"prompt_tokens", num_prompt_tokens},
             {"total_tokens", num_tokens_predicted + num_prompt_tokens}
         }},
-        {"id", completion_id}
+        {"id", completion_id},
+        {"truncated", truncated}
     };

+    // Add system_fingerprint if provided
+    if (result.contains("system_fingerprint")) {
+        res["system_fingerprint"] = result["system_fingerprint"];
+    }
+
     // extra fields for debugging purposes
     if (verbose) {
         res["__verbose"] = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
// return value is vector as there is one case where we might need to generate two responses
|
// return value is vector as there is one case where we might need to generate two responses
|
||||||
static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
|
static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id, bool legacy_format = false) {
|
||||||
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
|
// Early return if required fields are missing
|
||||||
return std::vector<json>({result});
|
// if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
|
||||||
}
|
// return std::vector<json>({result});
|
||||||
|
// }
|
||||||
|
|
||||||
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
|
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
|
||||||
std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
|
std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
|
||||||
|
std::string content = json_value(result, "content", std::string(""));
|
||||||
|
std::time_t t = std::time(0);
|
||||||
|
|
||||||
bool stopped_word = json_value(result, "stopped_word", false);
|
// Determine finish reason
|
||||||
bool stopped_eos = json_value(result, "stopped_eos", false);
|
|
||||||
bool stopped_limit = json_value(result, "stopped_limit", false);
|
|
||||||
std::string content = json_value(result, "content", std::string(""));
|
|
||||||
|
|
||||||
std::string finish_reason;
|
std::string finish_reason;
|
||||||
if (stopped_word || stopped_eos) {
|
if (json_value(result, "stopped_word", false) || json_value(result, "stopped_eos", false)) {
|
||||||
finish_reason = "stop";
|
finish_reason = "stop";
|
||||||
}
|
}
|
||||||
if (stopped_limit) {
|
if (json_value(result, "stopped_limit", false)) {
|
||||||
finish_reason = "length";
|
finish_reason = "length";
|
||||||
}
|
}
|
||||||
|
|
||||||
std::time_t t = std::time(0);
|
|
||||||
|
|
||||||
json choices;
|
json choices;
|
||||||
|
|
||||||
if (!finish_reason.empty()) {
|
if (!finish_reason.empty()) {
|
||||||
choices = json::array({json{{"finish_reason", finish_reason},
|
// Final message with finish reason
|
||||||
{"index", 0},
|
if (legacy_format) {
|
||||||
{"delta", json::object()}}});
|
choices = json::array({json{
|
||||||
} else {
|
{"finish_reason", finish_reason},
|
||||||
if (first) {
|
{"index", 0},
|
||||||
if (content.empty()) {
|
{"logprobs", result.contains("completion_probabilities") ?
|
||||||
choices = json::array({json{{"finish_reason", nullptr},
|
result["completion_probabilities"] : nullptr},
|
||||||
{"index", 0},
|
{"text", content}
|
||||||
{"delta", json{{"role", "assistant"}}}}});
|
}});
|
||||||
} else {
|
|
||||||
// We have to send this as two updates to conform to openai behavior
|
|
||||||
json initial_ret = json{{"choices", json::array({json{
|
|
||||||
{"finish_reason", nullptr},
|
|
||||||
{"index", 0},
|
|
||||||
{"delta", json{
|
|
||||||
{"role", "assistant"}
|
|
||||||
}}}})},
|
|
||||||
{"created", t},
|
|
||||||
{"id", completion_id},
|
|
||||||
{"model", modelname},
|
|
||||||
{"object", "chat.completion.chunk"}};
|
|
||||||
|
|
||||||
json second_ret = json{
|
|
||||||
{"choices", json::array({json{{"finish_reason", nullptr},
|
|
||||||
{"index", 0},
|
|
||||||
{"delta", json{
|
|
||||||
{"content", content}}}
|
|
||||||
}})},
|
|
||||||
{"created", t},
|
|
||||||
{"id", completion_id},
|
|
||||||
{"model", modelname},
|
|
||||||
{"object", "chat.completion.chunk"}};
|
|
||||||
|
|
||||||
return std::vector<json>({initial_ret, second_ret});
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// Some idiosyncrasy in task processing logic makes several trailing calls
|
choices = json::array({json{
|
||||||
// with empty content, we ignore these at the calee site.
|
{"finish_reason", finish_reason},
|
||||||
if (content.empty()) {
|
{"index", 0},
|
||||||
return std::vector<json>({json::object()});
|
{"delta", json::object()}
|
||||||
}
|
}});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Content message
|
||||||
|
if (legacy_format) {
|
||||||
choices = json::array({json{
|
choices = json::array({json{
|
||||||
{"finish_reason", nullptr},
|
{"finish_reason", nullptr},
|
||||||
{"index", 0},
|
{"index", 0},
|
||||||
{"delta",
|
{"logprobs", result.contains("completion_probabilities") ?
|
||||||
json{
|
result["completion_probabilities"] : nullptr},
|
||||||
{"content", content},
|
{"text", content}
|
||||||
}},
|
|
||||||
}});
|
}});
|
||||||
|
} else {
|
||||||
|
if (first) {
|
||||||
|
if (content.empty()) {
|
||||||
|
choices = json::array({json{
|
||||||
|
{"finish_reason", nullptr},
|
||||||
|
{"index", 0},
|
||||||
|
{"delta", json{{"role", "assistant"}}}
|
||||||
|
}});
|
||||||
|
} else {
|
||||||
|
if (content.empty()) {
|
||||||
|
return std::vector<json>({json::object()});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split into two messages for first content in chat mode
|
||||||
|
json initial_ret = json{
|
||||||
|
{"choices", json::array({json{
|
||||||
|
{"finish_reason", nullptr},
|
||||||
|
{"index", 0},
|
||||||
|
{"delta", json{{"role", "assistant"}}}
|
||||||
|
}})},
|
||||||
|
{"created", t},
|
||||||
|
{"id", completion_id},
|
||||||
|
{"model", modelname},
|
||||||
|
{"object", "chat.completion.chunk"}
|
||||||
|
};
|
||||||
|
|
||||||
|
json second_ret = json{
|
||||||
|
{"choices", json::array({json{
|
||||||
|
{"finish_reason", nullptr},
|
||||||
|
{"index", 0},
|
||||||
|
{"delta", json{{"content", content}}}
|
||||||
|
}})},
|
||||||
|
{"created", t},
|
||||||
|
{"id", completion_id},
|
||||||
|
{"model", modelname},
|
||||||
|
{"object", "chat.completion.chunk"}
|
||||||
|
};
|
||||||
|
|
||||||
|
return std::vector<json>({initial_ret, second_ret});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
choices = json::array({json{
|
||||||
|
{"finish_reason", nullptr},
|
||||||
|
{"index", 0},
|
||||||
|
{"delta", json{{"content", content}}}
|
||||||
|
}});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
json ret = json {
|
// Construct the response
|
||||||
|
json ret = json{
|
||||||
{"choices", choices},
|
{"choices", choices},
|
||||||
{"created", t},
|
{"created", t},
|
||||||
{"id", completion_id},
|
{"id", completion_id},
|
||||||
{"model", modelname},
|
{"model", modelname},
|
||||||
{"object", "chat.completion.chunk"}
|
{"object", legacy_format ? "text_completion" : "chat.completion.chunk"}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Add timings if present
|
||||||
if (result.contains("timings")) {
|
if (result.contains("timings")) {
|
||||||
ret.push_back({"timings", json_value(result, "timings", json::object())});
|
ret["timings"] = json_value(result, "timings", json::object());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add usage statistics for final messages
|
||||||
if (!finish_reason.empty()) {
|
if (!finish_reason.empty()) {
|
||||||
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
|
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
|
||||||
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
|
||||||
ret.push_back({"usage", json {
|
ret["usage"] = json{
|
||||||
{"completion_tokens", num_tokens_predicted},
|
{"completion_tokens", num_tokens_predicted},
|
||||||
{"prompt_tokens", num_prompt_tokens},
|
{"prompt_tokens", num_prompt_tokens},
|
||||||
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
|
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
|
||||||
}});
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::vector<json>({ret});
|
return std::vector<json>({ret});
|
||||||
|
|