server: add OpenAI compatible response format for /completions with backward compatibility

This commit is contained in:
Oren Collaco 2024-12-03 16:59:29 -07:00
parent 3b4f2e33e2
commit 938dbd4d6e
9 changed files with 470 additions and 186 deletions

View file

@ -924,7 +924,7 @@ struct server_context {
slot.params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
slot.params.sampling.penalize_nl = json_value(data, "penalize_nl", defaults.sampling.penalize_nl);
slot.params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
slot.params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
slot.params.sampling.n_probs = json_value(data, "n_probs", json_value(data, "logprobs", defaults.sampling.n_probs));
slot.params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
@ -1340,7 +1340,8 @@ struct server_context {
}
slot.n_sent_token_probs = probs_stop_pos;
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
// TODO: bool to determine if we are using the new format (/chat/completions) or the legacy format (/completions) for logprobs
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output, true);
}
if (slot.oaicompat) {
@ -1379,7 +1380,7 @@ struct server_context {
{"timings", slot.get_formated_timings()},
{"index", slot.index},
};
if (slot.params.sampling.n_probs > 0) {
std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) {
@ -1395,7 +1396,8 @@ struct server_context {
slot.generated_token_probs.end());
}
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
// TODO: bool to determine if we are using the new format (/chat/completions) or the legacy format (/completions) for logprobs
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs, true);
}
if (slot.oaicompat) {
@ -2901,31 +2903,63 @@ int main(int argc, char ** argv) {
res_ok(res, {{ "success", true }});
};
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
const auto handle_completions_generic = [&ctx_server, &params, &res_error, &res_ok, verbose](server_task_inf_type inf_type, json & data, httplib::Response & res, bool is_chat = false) {
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
// Parse data for /chat/completions format if needed
if (is_chat) {
data = oaicompat_completion_params_parse(ctx_server.model, data, params.chat_template);
}
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false);
bool oai_compat = json_value(data, "oai_compat", true);
const auto task_ids = server_task::get_list_id(tasks);
const auto completion_id = gen_chatcmplid();
if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
if (results.size() == 1) {
// single result
res_ok(res, results[0].data);
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & res : results) {
arr.push_back(res.data);
if (inf_type == SERVER_TASK_INF_TYPE_COMPLETION && (oai_compat || is_chat)) {
if (is_chat) {
// multitask is never supported in chat completion, there is only one result
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id,
/*.streaming =*/ false, verbose, /*.legacy_format =*/ !is_chat);
res_ok(res, result_oai);
} else {
if (results.size() == 1) {
// single result
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id,
/*.streaming =*/ false, verbose, /*.legacy_format =*/ true);
res_ok(res, result_oai);
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & result : results) {
arr.push_back(format_final_response_oaicompat(data, result.data, completion_id,
/*.streaming =*/ false, verbose, /*.legacy_format =*/ true));
}
res_ok(res, arr);
}
}
}
else{
if (results.size() == 1) {
// single result
res_ok(res, results[0].data);
} else {
// multiple results (multitask)
json arr = json::array();
for (const auto & res : results) {
arr.push_back(res.data);
}
res_ok(res, arr);
}
res_ok(res, arr);
}
}, [&](const json & error_data) {
res_error(res, error_data);
@ -2933,14 +2967,35 @@ int main(int argc, char ** argv) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id, is_chat, inf_type, oai_compat](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
return server_sent_event(sink, "data", result.data);
if (inf_type == SERVER_TASK_INF_TYPE_COMPLETION && (oai_compat || is_chat)) {
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id, !is_chat);
for (auto & event_data : result_array) {
if (event_data.empty()) {
continue; // skip the stop token
}
if (!server_sent_event(sink, "data", event_data)) {
return false; // connection is closed
}
}
return true; // ok
}
return server_sent_event(sink, "data", result.data);
}, [&](const json & error_data) {
server_sent_event(sink, "error", error_data);
});
if (is_chat) {
static const std::string ev_done = "data: [DONE]\n\n";
sink.write(ev_done.data(), ev_done.size());
}
sink.done();
return false;
return true;
};
auto on_complete = [task_ids, &ctx_server] (bool) {
@ -2953,7 +3008,12 @@ int main(int argc, char ** argv) {
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, false);
};
const auto handle_chat_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res, true);
};
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@ -3006,63 +3066,6 @@ int main(int argc, char ** argv) {
return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
};
// TODO: maybe merge this function with "handle_completions_generic"
const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
ctx_server.queue_results.add_waiting_tasks(tasks);
ctx_server.queue_tasks.post(tasks);
bool stream = json_value(data, "stream", false);
const auto task_ids = server_task::get_list_id(tasks);
const auto completion_id = gen_chatcmplid();
if (!stream) {
ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
// multitask is never support in chat completion, there is only one result
json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
res_ok(res, result_oai);
}, [&](const json & error_data) {
res_error(res, error_data);
});
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
} else {
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
for (auto & event_data : result_array) {
if (event_data.empty()) {
continue; // skip the stop token
}
if (!server_sent_event(sink, "data", event_data)) {
return false; // connection is closed
}
}
return true; // ok
}, [&](const json & error_data) {
server_sent_event(sink, "error", error_data);
});
static const std::string ev_done = "data: [DONE]\n\n";
sink.write(ev_done.data(), ev_done.size());
sink.done();
return true;
};
auto on_complete = [task_ids, &ctx_server] (bool) {
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};
const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
json models = {
{"object", "list"},

View file

@ -40,9 +40,19 @@ def test_load_split_model():
server.model_alias = "tinyllama-split"
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 16,
"max_tokens": 16,
"prompt": "Hello",
"temperature": 0.0,
})
assert res.status_code == 200
assert match_regex("(little|girl)+", res.body["content"])
# Verify response structure
assert "id" in res.body
assert "object" in res.body
assert "created" in res.body
assert "model" in res.body
assert "choices" in res.body
assert isinstance(res.body["choices"], list)
assert len(res.body["choices"]) > 0
assert "text" in res.body["choices"][0]
# Verify the actual content
assert match_regex("(little|girl)+", res.body["choices"][0]["text"])

View file

@ -18,13 +18,13 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": n_predict,
"max_tokens": n_predict,
"prompt": prompt,
"oai_compat": False,
})
assert res.status_code == 200
assert res.body["timings"]["prompt_n"] == n_prompt
assert res.body["timings"]["predicted_n"] == n_predicted
assert res.body["truncated"] == truncated
assert match_regex(re_content, res.body["content"])
@ -36,16 +36,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
global server
server.start()
res = server.make_stream_request("POST", "/completion", data={
"n_predict": n_predict,
"max_tokens": n_predict,
"prompt": prompt,
"stream": True,
"oai_compat": False,
})
content = ""
for data in res:
if data["stop"]:
assert data["timings"]["prompt_n"] == n_prompt
assert data["timings"]["predicted_n"] == n_predicted
assert data["truncated"] == truncated
assert match_regex(re_content, content)
else:
content += data["content"]
@ -63,6 +63,7 @@ def test_consistent_result_same_seed(n_slots: int):
"seed": 42,
"temperature": 1.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
"oai_compat": False,
})
if last_res is not None:
assert res.body["content"] == last_res.body["content"]
@ -81,6 +82,7 @@ def test_different_result_different_seed(n_slots: int):
"seed": seed,
"temperature": 1.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
"oai_compat": False,
})
if last_res is not None:
assert res.body["content"] != last_res.body["content"]
@ -100,6 +102,7 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
"seed": 42,
"temperature": temperature,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
"oai_compat": False,
})
if last_res is not None:
assert res.body["content"] == last_res.body["content"]
@ -115,12 +118,14 @@ def test_cache_vs_nocache_prompt():
"seed": 42,
"temperature": 1.0,
"cache_prompt": True,
"oai_compat": False,
})
res_no_cache = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 1.0,
"cache_prompt": False,
"oai_compat": False,
})
assert res_cache.body["content"] == res_no_cache.body["content"]
@ -140,6 +145,7 @@ def test_completion_with_tokens_input():
# single completion
res = server.make_request("POST", "/completion", data={
"prompt": tokens,
"oai_compat": False,
})
assert res.status_code == 200
assert type(res.body["content"]) == str
@ -147,6 +153,7 @@ def test_completion_with_tokens_input():
# batch completion
res = server.make_request("POST", "/completion", data={
"prompt": [tokens, tokens],
"oai_compat": False,
})
assert res.status_code == 200
assert type(res.body) == list
@ -156,6 +163,7 @@ def test_completion_with_tokens_input():
# mixed string and tokens
res = server.make_request("POST", "/completion", data={
"prompt": [tokens, prompt_str],
"oai_compat": False,
})
assert res.status_code == 200
assert type(res.body) == list
@ -165,6 +173,7 @@ def test_completion_with_tokens_input():
# mixed string and tokens in one sequence
res = server.make_request("POST", "/completion", data={
"prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
"oai_compat": False,
})
assert res.status_code == 200
assert type(res.body["content"]) == str
@ -208,6 +217,7 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
"prompt": prompt,
"seed": 42,
"temperature": 1.0,
"oai_compat": False,
})))
tasks.append((check_slots_status, ()))
results = parallel_function_calls(tasks)
@ -221,3 +231,122 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
assert len(res.body["content"]) > 10
# FIXME: the result is not deterministic when using other slot than slot 0
# assert match_regex(re_content, res.body["content"])
# OpenAI legacy completion endpoint tests
@pytest.mark.parametrize("prompt,n_predict,expected_text,n_prompt,n_predicted", [
("I believe the meaning of life is", 8, "going to bed", 18, 8),
("Write a joke about", 16, "Why did the AI", 14, 16),
])
def test_completion_openai(prompt: str, n_predict: int, expected_text: str, n_prompt: int, n_predicted: int):
global server
server.start()
# Test non-streaming response
res = server.make_request("POST", "/completions", data={
"model": "local-model",
"prompt": prompt,
"max_tokens": n_predict,
"logprobs": 3,
"echo": True
})
assert res.status_code == 200
assert res.body["object"] == "text_completion"
assert isinstance(res.body["id"], str)
assert isinstance(res.body["created"], int)
assert res.body["model"] == "local-model"
# Check choices array
assert len(res.body["choices"]) == 1
choice = res.body["choices"][0]
assert choice["index"] == 0
assert isinstance(choice["text"], str)
assert choice["finish_reason"] in ["stop", "length"]
# Check logprobs
assert choice["logprobs"] is not None
assert "tokens" in choice["logprobs"]
assert "token_logprobs" in choice["logprobs"]
assert "top_logprobs" in choice["logprobs"]
assert len(choice["logprobs"]["top_logprobs"]) == len(choice["logprobs"]["tokens"])
# Check usage statistics
assert "usage" in res.body
assert res.body["usage"]["prompt_tokens"] == n_prompt
assert res.body["usage"]["completion_tokens"] == n_predicted
assert res.body["usage"]["total_tokens"] == n_prompt + n_predicted
@pytest.mark.parametrize("prompt,n_predict,expected_text,n_prompt,n_predicted", [
("I believe the meaning of life is", 8, "going to bed", 18, 8),
("Write a joke about", 16, "Why did the AI", 14, 16),
])
def test_completion_openai_stream(prompt: str, n_predict: int, expected_text: str, n_prompt: int, n_predicted: int):
global server
server.start()
res = server.make_stream_request("POST", "/v1/completions", data={
"prompt": prompt,
"max_tokens": n_predict,
"stream": True,
})
collected_text = ""
is_first_chunk = True
for data in res:
assert "id" in data
assert data["object"] == "text_completion"
assert isinstance(data["created"], int)
assert len(data["choices"]) == 1
choice = data["choices"][0]
assert choice["index"] == 0
assert isinstance(choice["text"], str)
collected_text += choice["text"]
if is_first_chunk:
# First chunk should have model info
is_first_chunk = False
if choice["finish_reason"] is not None:
# This is the last chunk
assert choice["finish_reason"] in ["stop", "length"]
assert "usage" in data
assert data["usage"]["prompt_tokens"] == n_prompt
assert data["usage"]["completion_tokens"] == n_predicted
assert data["usage"]["total_tokens"] == n_prompt + n_predicted
@pytest.mark.parametrize("prompt,n_predict,expected_text,n_prompt,n_predicted", [
("I believe the meaning of life is", 8, "going to bed", 18, 8),
("Write a joke about", 16, "Why did the AI", 14, 16),
])
def test_completion_openai_no_logprobs(prompt: str, n_predict: int, expected_text: str, n_prompt: int, n_predicted: int):
global server
server.start()
# Test non-streaming response
res = server.make_request("POST", "/completions", data={
"prompt": prompt,
"max_tokens": n_predict,
"echo": True
})
assert res.status_code == 200
assert res.body["object"] == "text_completion"
assert isinstance(res.body["id"], str)
assert isinstance(res.body["created"], int)
# Check choices array
assert len(res.body["choices"]) == 1
choice = res.body["choices"][0]
assert choice["index"] == 0
assert isinstance(choice["text"], str)
assert choice["finish_reason"] in ["stop", "length"]
# Verify logprobs is None when not requested
assert choice["logprobs"] is None
# Check usage statistics
assert "usage" in res.body
assert res.body["usage"]["prompt_tokens"] == n_prompt
assert res.body["usage"]["completion_tokens"] == n_predicted
assert res.body["usage"]["total_tokens"] == n_prompt + n_predicted

View file

@ -29,6 +29,7 @@ def test_ctx_shift_enabled():
res = server.make_request("POST", "/completion", data={
"n_predict": 64,
"prompt": LONG_TEXT,
"oai_compat": False,
})
assert res.status_code == 200
assert res.body["timings"]["prompt_n"] == 109
@ -48,6 +49,7 @@ def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, tr
res = server.make_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": "Hi how are you",
"oai_compat": False,
})
assert res.status_code == 200
assert res.body["timings"]["predicted_n"] == n_token_output
@ -61,6 +63,7 @@ def test_ctx_shift_disabled_long_prompt():
res = server.make_request("POST", "/completion", data={
"n_predict": 64,
"prompt": LONG_TEXT,
"oai_compat": False,
})
assert res.status_code != 200
assert "error" in res.body

View file

@ -36,6 +36,7 @@ def test_lora(scale: float, re_content: str):
assert res_lora_control.status_code == 200
res = server.make_request("POST", "/completion", data={
"prompt": "Look in thy glass",
"oai_compat": False,
})
assert res.status_code == 200
assert match_regex(re_content, res.body["content"])

View file

@ -41,6 +41,7 @@ def test_correct_api_key():
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
"oai_compat": False,
}, headers={
"Authorization": f"Bearer {TEST_API_KEY}",
})

View file

@ -20,6 +20,7 @@ def test_slot_save_restore():
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
"oai_compat": False,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
@ -37,6 +38,7 @@ def test_slot_save_restore():
"prompt": "What is the capital of Germany?",
"id_slot": 1,
"cache_prompt": True,
"oai_compat": False,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
@ -54,6 +56,7 @@ def test_slot_save_restore():
"prompt": "What is the capital of Germany?",
"id_slot": 0,
"cache_prompt": True,
"oai_compat": False,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
@ -64,6 +67,7 @@ def test_slot_save_restore():
"prompt": "What is the capital of Germany?",
"id_slot": 1,
"cache_prompt": True,
"oai_compat": False,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
@ -78,6 +82,7 @@ def test_slot_erase():
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
"oai_compat": False,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
@ -94,5 +99,5 @@ def test_slot_erase():
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert match_regex("(Whiskers|Flana)+", res.body["choices"][0]["text"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed

View file

@ -37,6 +37,7 @@ def test_with_and_without_draft():
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
"oai_compat": False,
})
assert res.status_code == 200
content_no_draft = res.body["content"]
@ -49,6 +50,7 @@ def test_with_and_without_draft():
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
"oai_compat": False,
})
assert res.status_code == 200
content_draft = res.body["content"]
@ -75,6 +77,7 @@ def test_different_draft_min_draft_max():
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
"oai_compat": False,
})
assert res.status_code == 200
if last_content is not None:
@ -96,6 +99,7 @@ def test_multi_requests_parallel(n_slots: int, n_requests: int):
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
"oai_compat": False,
})))
results = parallel_function_calls(tasks)
for res in results:

View file

@ -498,28 +498,105 @@ struct completion_token_output {
};
// convert a vector of completion_token_output to json
static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
json out = json::array();
for (const auto & prob : probs) {
json probs_for_token = json::array();
for (const auto & p : prob.probs) {
const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
probs_for_token.push_back(json {
{"tok_str", tok_str},
{"prob", p.prob},
});
static json probs_vector_to_json(llama_context * ctx, const std::vector<completion_token_output> & probs, bool legacy_format = true) {
if (legacy_format) {
// Legacy format (text_completion endpoint)
json logprobs;
std::vector<std::string> tokens;
std::vector<json> token_logprobs; // Changed to json to allow null values
std::vector<json> top_logprobs; // Changed to allow null values
std::vector<int> text_offset;
int current_offset = 0;
for (const auto & prob : probs) {
std::string token_str = tokens_to_output_formatted_string(ctx, prob.tok);
tokens.push_back(token_str);
text_offset.push_back(current_offset);
// Handle token logprobs
if (!prob.probs.empty() && prob.probs[0].prob > 0) {
token_logprobs.push_back(std::log(prob.probs[0].prob));
} else {
token_logprobs.push_back(nullptr);
}
// Handle top logprobs
json token_top_logprobs = json::object();
for (const auto & p : prob.probs) {
if (p.prob > 0) {
token_top_logprobs[tokens_to_output_formatted_string(ctx, p.tok)] = std::log(p.prob);
}
}
top_logprobs.push_back(token_top_logprobs.empty() ? nullptr : token_top_logprobs);
current_offset += token_str.length();
}
logprobs = {
{"tokens", tokens},
{"token_logprobs", token_logprobs},
{"top_logprobs", top_logprobs},
{"text_offset", text_offset}
};
return logprobs;
} else {
// New format (GPT-4 style)
json logprobs;
std::vector<json> content;
for (const auto & prob : probs) {
std::string token_str = tokens_to_output_formatted_string(ctx, prob.tok);
// Create top_logprobs array for this token
json token_top_logprobs = json::array();
for (const auto & p : prob.probs) {
if (p.prob > 0) {
// Get UTF-8 bytes for the token
std::vector<int> bytes;
std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
for (unsigned char c : tok_str) {
bytes.push_back(static_cast<int>(c));
}
json entry = {
{"token", tok_str},
{"logprob", std::log(p.prob)},
{"bytes", bytes.empty() ? json(nullptr) : json(bytes)}
};
token_top_logprobs.push_back(entry);
}
}
const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
out.push_back(json {
{"content", tok_str},
{"probs", probs_for_token},
});
// Get main token logprob
float main_logprob = (!prob.probs.empty() && prob.probs[0].prob > 0)
? std::log(prob.probs[0].prob)
: -9999.0f;
// Get UTF-8 bytes for the main token
std::vector<int> main_bytes;
for (unsigned char c : token_str) {
main_bytes.push_back(static_cast<int>(c));
}
// Add token info to content array
json token_info = {
{"token", token_str},
{"logprob", main_logprob},
{"bytes", main_bytes.empty() ? json(nullptr) : json(main_bytes)},
{"top_logprobs", token_top_logprobs}
};
content.push_back(token_info);
}
logprobs = {
{"content", content},
{"refusal", nullptr} // Add refusal field as null since we don't implement content filtering
};
return logprobs;
}
return out;
}
static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
@ -540,7 +617,8 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
static json oaicompat_completion_params_parse(
const struct llama_model * model,
const json & body, /* openai api json semantics */
const std::string & chat_template) {
const std::string & chat_template
) {
json llama_params;
llama_params["__oaicompat"] = true;
@ -604,43 +682,71 @@ static json oaicompat_completion_params_parse(
return llama_params;
}
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false, bool legacy_format = false) {
bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false);
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
std::string content = json_value(result, "content", std::string(""));
bool truncated = json_value(result, "truncated", false);
std::string finish_reason = "length";
if (stopped_word || stopped_eos) {
finish_reason = "stop";
}
json choices =
streaming ? json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}})
: json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"message", json{{"content", content},
{"role", "assistant"}}}}});
json choices;
// Use the pre-formatted logprobs directly
json logprobs = result.contains("completion_probabilities") ?
result["completion_probabilities"] : nullptr;
if (legacy_format) {
choices = json::array({json{
{"finish_reason", finish_reason},
{"index", 0},
{"logprobs", logprobs},
{"text", content}
}});
} else {
// Format for chat completions endpoint
choices = streaming ?
json::array({json{
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}
}}) :
json::array({json{
{"finish_reason", finish_reason},
{"index", 0},
{"message", json{
{"content", content},
{"role", "assistant"}
}}
}});
}
std::time_t t = std::time(0);
json res = json {
{"choices", choices},
{"created", t},
{"model",
json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", streaming ? "chat.completion.chunk" : "chat.completion"},
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", legacy_format ? "text_completion" :
(streaming ? "chat.completion.chunk" : "chat.completion")},
{"usage", json {
{"completion_tokens", num_tokens_predicted},
{"prompt_tokens", num_prompt_tokens},
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
{"prompt_tokens", num_prompt_tokens},
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
}},
{"id", completion_id}
{"id", completion_id},
{"truncated", truncated}
};
// Add system_fingerprint if provided
if (result.contains("system_fingerprint")) {
res["system_fingerprint"] = result["system_fingerprint"];
}
// extra fields for debugging purposes
if (verbose) {
res["__verbose"] = result;
@ -658,105 +764,127 @@ static json format_final_response_oaicompat(const json & request, const json & r
}
// return value is vector as there is one case where we might need to generate two responses
static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
return std::vector<json>({result});
}
static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id, bool legacy_format = false) {
// Early return if required fields are missing
// if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
// return std::vector<json>({result});
// }
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
std::string content = json_value(result, "content", std::string(""));
std::time_t t = std::time(0);
bool stopped_word = json_value(result, "stopped_word", false);
bool stopped_eos = json_value(result, "stopped_eos", false);
bool stopped_limit = json_value(result, "stopped_limit", false);
std::string content = json_value(result, "content", std::string(""));
// Determine finish reason
std::string finish_reason;
if (stopped_word || stopped_eos) {
if (json_value(result, "stopped_word", false) || json_value(result, "stopped_eos", false)) {
finish_reason = "stop";
}
if (stopped_limit) {
if (json_value(result, "stopped_limit", false)) {
finish_reason = "length";
}
std::time_t t = std::time(0);
json choices;
if (!finish_reason.empty()) {
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
} else {
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"content", content}}}
}})},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}};
return std::vector<json>({initial_ret, second_ret});
}
// Final message with finish reason
if (legacy_format) {
choices = json::array({json{
{"finish_reason", finish_reason},
{"index", 0},
{"logprobs", result.contains("completion_probabilities") ?
result["completion_probabilities"] : nullptr},
{"text", content}
}});
} else {
// Some idiosyncrasy in task processing logic makes several trailing calls
// with empty content, we ignore these at the calee site.
if (content.empty()) {
return std::vector<json>({json::object()});
}
choices = json::array({json{
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}
}});
}
} else {
// Content message
if (legacy_format) {
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json{
{"content", content},
}},
{"logprobs", result.contains("completion_probabilities") ?
result["completion_probabilities"] : nullptr},
{"text", content}
}});
} else {
if (first) {
if (content.empty()) {
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}
}});
} else {
if (content.empty()) {
return std::vector<json>({json::object()});
}
// Split into two messages for first content in chat mode
json initial_ret = json{
{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}
}})},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
json second_ret = json{
{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"content", content}}}
}})},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
};
return std::vector<json>({initial_ret, second_ret});
}
} else {
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"content", content}}}
}});
}
}
}
json ret = json {
// Construct the response
json ret = json{
{"choices", choices},
{"created", t},
{"id", completion_id},
{"model", modelname},
{"object", "chat.completion.chunk"}
{"id", completion_id},
{"model", modelname},
{"object", legacy_format ? "text_completion" : "chat.completion.chunk"}
};
// Add timings if present
if (result.contains("timings")) {
ret.push_back({"timings", json_value(result, "timings", json::object())});
ret["timings"] = json_value(result, "timings", json::object());
}
// Add usage statistics for final messages
if (!finish_reason.empty()) {
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
ret.push_back({"usage", json {
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
ret["usage"] = json{
{"completion_tokens", num_tokens_predicted},
{"prompt_tokens", num_prompt_tokens},
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
}});
{"prompt_tokens", num_prompt_tokens},
{"total_tokens", num_tokens_predicted + num_prompt_tokens}
};
}
return std::vector<json>({ret});