change to "response_fields"
This commit is contained in:
parent
4cf1fef320
commit
b8679c0bb5
3 changed files with 12 additions and 12 deletions
|
@ -450,7 +450,7 @@ These words will not be included in the completion, so make sure to add them to
|
|||
|
||||
`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
|
||||
|
||||
`requested_fields`: A list of response fields, for example: `"requested_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error.
|
||||
`response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error.
|
||||
|
||||
**Response format**
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ struct slot_params {
|
|||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
std::vector<std::string> requested_fields;
|
||||
std::vector<std::string> response_fields;
|
||||
bool timings_per_token = false;
|
||||
bool post_sampling_probs = false;
|
||||
bool ignore_eos = false;
|
||||
|
@ -210,7 +210,7 @@ struct server_task {
|
|||
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
||||
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
||||
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
||||
params.requested_fields = json_value(data, "requested_fields", std::vector<std::string>());
|
||||
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
|
||||
|
||||
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
|
||||
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
|
||||
|
@ -524,7 +524,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
|
||||
bool post_sampling_probs;
|
||||
std::vector<completion_token_output> probs_output;
|
||||
std::vector<std::string> requested_fields;
|
||||
std::vector<std::string> response_fields;
|
||||
|
||||
slot_params generation_params;
|
||||
|
||||
|
@ -571,7 +571,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
if (!stream && !probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
||||
}
|
||||
return requested_fields.empty() ? res : json_get_nested_values(requested_fields, res);
|
||||
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
||||
}
|
||||
|
||||
json to_json_oaicompat_chat() {
|
||||
|
@ -2066,7 +2066,7 @@ struct server_context {
|
|||
res->tokens = slot.generated_tokens;
|
||||
res->timings = slot.get_timings();
|
||||
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
|
||||
res->requested_fields = slot.params.requested_fields;
|
||||
res->response_fields = slot.params.response_fields;
|
||||
|
||||
res->truncated = slot.truncated;
|
||||
res->n_decoded = slot.n_decoded;
|
||||
|
|
|
@ -258,14 +258,14 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt,n_predict,requested_fields",
|
||||
"prompt,n_predict,response_fields",
|
||||
[
|
||||
("I believe the meaning of life is", 8, []),
|
||||
("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
|
||||
],
|
||||
)
|
||||
def test_completion_requested_fields(
|
||||
prompt: str, n_predict: int, requested_fields: list[str]
|
||||
def test_completion_response_fields(
|
||||
prompt: str, n_predict: int, response_fields: list[str]
|
||||
):
|
||||
global server
|
||||
server.start()
|
||||
|
@ -275,17 +275,17 @@ def test_completion_requested_fields(
|
|||
data={
|
||||
"n_predict": n_predict,
|
||||
"prompt": prompt,
|
||||
"requested_fields": requested_fields,
|
||||
"response_fields": response_fields,
|
||||
},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert "content" in res.body
|
||||
assert len(res.body["content"])
|
||||
if len(requested_fields):
|
||||
if len(response_fields):
|
||||
assert res.body["generation_settings/n_predict"] == n_predict
|
||||
assert res.body["prompt"] == "<s> " + prompt
|
||||
assert isinstance(res.body["content"], str)
|
||||
assert len(res.body) == len(requested_fields)
|
||||
assert len(res.body) == len(response_fields)
|
||||
else:
|
||||
assert len(res.body)
|
||||
assert "generation_settings" in res.body
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue