From 8734df73d9a470181ba82b5932b2980e35972fb9 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 18 Dec 2024 17:15:15 +0100
Subject: [PATCH] remove --multi-token-probs

---
 common/arg.cpp                                      | 10 ----------
 common/common.h                                     |  1 -
 examples/server/README.md                           |  1 -
 examples/server/server.cpp                          |  4 ----
 examples/server/tests/unit/test_chat_completion.py  |  2 --
 examples/server/tests/unit/test_completion.py       |  3 ---
 examples/server/tests/utils.py                      |  3 ---
 7 files changed, 24 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index e3f546b76..3d55289c3 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1085,16 +1085,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.grammar = json_schema_to_grammar(json::parse(value));
         }
     ).set_sparam());
-    add_opt(common_arg(
-        {"-mtp", "--multi-token-probs"},
-        string_format(
-            "allow getting probabilities for multiple tokens. note: this will slow down the generation speed (default: %s)",
-            params.sampling.multi_token_probs ? "enabled" : "disabled"
-        ),
-        [](common_params & params) {
-            params.sampling.multi_token_probs = true;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MULTI_TOKEN_PROBS"));
     add_opt(common_arg(
         {"--pooling"}, "{none,mean,cls,last,rank}",
         "pooling type for embeddings, use model default if unspecified",
diff --git a/common/common.h b/common/common.h
index 9ec4c9f4b..ec0e49f6f 100644
--- a/common/common.h
+++ b/common/common.h
@@ -134,7 +134,6 @@ struct common_params_sampling {
     bool ignore_eos = false;
     bool no_perf = false; // disable performance metrics
     bool timing_per_token = false;
-    bool multi_token_probs = false; // output probabilities for multiple tokens (when n_probs > 0)
 
     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
diff --git a/examples/server/README.md b/examples/server/README.md
index 480e40d30..73e394cfb 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -139,7 +139,6 @@ The project is under active development, and we are [looking for feedback and co
 | `-sp, --special` | special tokens output enabled (default: false) |
 | `--no-warmup` | skip warming up the model with an empty run |
 | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) |
-| `-mtp, --multi-token-probs` | allow getting probabilities for multiple tokens. note: this will slow down the generation speed (default: disabled)<br/>(env: LLAMA_ARG_MULTI_TOKEN_PROBS) |
 | `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
 | `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
 | `-nocb, --no-cont-batching` | disable continuous batching<br/>(env: LLAMA_ARG_NO_CONT_BATCHING) |
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 93196adcd..1b20c8e59 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -243,10 +243,6 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
-        if (!params_base.sampling.multi_token_probs && params.n_predict > 1 && params.sampling.n_probs > 0) {
-            throw std::runtime_error("For performance reason, n_probs with n_predict > 1 is not allowed. To enable this, start the server with --multi-token-probs");
-        }
-
         // TODO: add more sanity checks for the input parameters
 
         if (params.sampling.penalty_last_n < -1) {
diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 37ac11006..0fa1a17c1 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -166,7 +166,6 @@ def test_chat_completion_with_timings_per_token():
 
 def test_logprobs():
     global server
-    server.multi_token_probs = True
     server.start()
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
     res = client.chat.completions.create(
@@ -194,7 +193,6 @@ def test_logprobs_stream():
     global server
-    server.multi_token_probs = True
     server.start()
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
     res = client.chat.completions.create(
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index 78aaed052..f583737ca 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -259,7 +259,6 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
 
 def test_n_probs():
     global server
-    server.multi_token_probs = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
@@ -285,7 +284,6 @@ def test_n_probs_stream():
     global server
-    server.multi_token_probs = True
     server.start()
     res = server.make_stream_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
@@ -313,7 +311,6 @@ def test_n_probs_post_sampling():
     global server
-    server.multi_token_probs = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index 12310fb22..277125e88 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -74,7 +74,6 @@ class ServerProcess:
     draft_min: int | None = None
     draft_max: int | None = None
    no_webui: bool | None = None
-    multi_token_probs: bool | None = None
 
     # session variables
     process: subprocess.Popen | None = None
@@ -165,8 +164,6 @@ class ServerProcess:
             server_args.extend(["--draft-min", self.draft_min])
         if self.no_webui:
             server_args.append("--no-webui")
-        if self.multi_token_probs:
-            server_args.append("--multi-token-probs")
 
         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"bench: starting server with: {' '.join(args)}")