From b8679c0bb5d37952163458ae699fb931de54d959 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Tue, 24 Dec 2024 16:28:44 +0100
Subject: [PATCH] change to "response_fields"

---
 examples/server/README.md                     |  2 +-
 examples/server/server.cpp                    | 10 +++++-----
 examples/server/tests/unit/test_completion.py | 12 ++++++------
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/examples/server/README.md b/examples/server/README.md
index 033a514ed..958d1cdd1 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -450,7 +450,7 @@ These words will not be included in the completion, so make sure to add them to
 
 `post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain.
 
-`requested_fields`: A list of response fields, for example: `"requested_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error.
+`response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error.
 
 **Response format**
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7b277b9dc..4affc7cde 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -92,7 +92,7 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
-    std::vector<std::string> requested_fields;
+    std::vector<std::string> response_fields;
     bool timings_per_token = false;
     bool post_sampling_probs = false;
     bool ignore_eos = false;
@@ -210,7 +210,7 @@ struct server_task {
         params.n_discard          = json_value(data, "n_discard",          defaults.n_discard);
       //params.t_max_prompt_ms    = json_value(data, "t_max_prompt_ms",    defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms   = json_value(data, "t_max_predict_ms",   defaults.t_max_predict_ms);
-        params.requested_fields   = json_value(data, "requested_fields",   std::vector<std::string>());
+        params.response_fields    = json_value(data, "response_fields",    std::vector<std::string>());
 
         params.sampling.top_k     = json_value(data, "top_k",              defaults.sampling.top_k);
         params.sampling.top_p     = json_value(data, "top_p",              defaults.sampling.top_p);
@@ -524,7 +524,7 @@ struct server_task_result_cmpl_final : server_task_result {
 
     bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
-    std::vector<std::string> requested_fields;
+    std::vector<std::string> response_fields;
 
     slot_params generation_params;
 
@@ -571,7 +571,7 @@
         if (!stream && !probs_output.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
         }
-        return requested_fields.empty() ? res : json_get_nested_values(requested_fields, res);
+        return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
     }
 
     json to_json_oaicompat_chat() {
@@ -2066,7 +2066,7 @@ struct server_context {
             res->tokens          = slot.generated_tokens;
             res->timings         = slot.get_timings();
             res->prompt          = common_detokenize(ctx, slot.prompt_tokens, true);
-            res->requested_fields = slot.params.requested_fields;
+            res->response_fields = slot.params.response_fields;
 
             res->truncated       = slot.truncated;
             res->n_decoded       = slot.n_decoded;
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index f7a427c33..00d5ce391 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -258,14 +258,14 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
 
 
 @pytest.mark.parametrize(
-    "prompt,n_predict,requested_fields",
+    "prompt,n_predict,response_fields",
     [
         ("I believe the meaning of life is", 8, []),
         ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
     ],
 )
-def test_completion_requested_fields(
-    prompt: str, n_predict: int, requested_fields: list[str]
+def test_completion_response_fields(
+    prompt: str, n_predict: int, response_fields: list[str]
 ):
     global server
     server.start()
@@ -275,17 +275,17 @@ def test_completion_requested_fields(
         data={
             "n_predict": n_predict,
             "prompt": prompt,
-            "requested_fields": requested_fields,
+            "response_fields": response_fields,
         },
     )
     assert res.status_code == 200
     assert "content" in res.body
     assert len(res.body["content"])
-    if len(requested_fields):
+    if len(response_fields):
         assert res.body["generation_settings/n_predict"] == n_predict
         assert res.body["prompt"] == " " + prompt
         assert isinstance(res.body["content"], str)
-        assert len(res.body) == len(requested_fields)
+        assert len(res.body) == len(response_fields)
     else:
         assert len(res.body)
         assert "generation_settings" in res.body
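For reference, a minimal client-side sketch (not part of the patch) of the renamed parameter in use. It assumes a llama.cpp server already listening at http://localhost:8080 and the third-party `requests` package; it mirrors the behavior exercised by test_completion_response_fields above, where nested fields are addressed with a "/" separator and returned under that flattened key.

    import requests  # third-party HTTP client, assumed installed

    # Ask the server to return only the listed fields; unknown fields
    # are silently omitted rather than raising an error.
    res = requests.post(
        "http://localhost:8080/completion",  # assumed server address
        json={
            "prompt": "I believe the meaning of life is",
            "n_predict": 32,
            "response_fields": ["content", "generation_settings/n_predict"],
        },
    )
    res.raise_for_status()
    body = res.json()

    # Only the requested fields are present in the response body,
    # with nested values flattened under their "/" paths.
    print(body["content"])
    print(body["generation_settings/n_predict"])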