diff --git a/common/arg.cpp b/common/arg.cpp
index 49af31682..622f24fb4 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1057,6 +1057,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.grammar = json_schema_to_grammar(json::parse(value));
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"-mtp", "--multi-token-probs"},
+        string_format(
+            "allow getting probabilities for multiple tokens. note: this will slow down the generation speed (default: %s)",
+            params.sampling.multi_token_probs ? "enabled" : "disabled"
+        ),
+        [](common_params & params) {
+            params.sampling.multi_token_probs = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MULTI_TOKEN_PROBS"));
     add_opt(common_arg(
         {"--pooling"}, "{none,mean,cls,last,rank}",
         "pooling type for embeddings, use model default if unspecified",
diff --git a/common/common.h b/common/common.h
index 95d20401d..5fcb8e506 100644
--- a/common/common.h
+++ b/common/common.h
@@ -134,6 +134,7 @@ struct common_params_sampling {
     bool ignore_eos       = false;
     bool no_perf          = false; // disable performance metrics
     bool timing_per_token = false;
+    bool multi_token_probs = false; // output probabilities for multiple tokens (when n_probs > 0)
 
     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 95bd531b3..8f5778052 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -239,6 +239,10 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+        if (!params_base.sampling.multi_token_probs && params.n_predict > 1 && params.sampling.n_probs > 0) {
+            throw std::runtime_error("For performance reasons, n_probs with n_predict > 1 is not allowed. To enable this, start the server with --multi-token-probs");
+        }
+
         if (params.sampling.dry_base < 1.0f) {
             params.sampling.dry_base = defaults.sampling.dry_base;
         }
diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 0fa1a17c1..37ac11006 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -166,6 +166,7 @@ def test_chat_completion_with_timings_per_token():
 
 def test_logprobs():
     global server
+    server.multi_token_probs = True
     server.start()
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
     res = client.chat.completions.create(
@@ -193,6 +194,7 @@ def test_logprobs():
 
 def test_logprobs_stream():
     global server
+    server.multi_token_probs = True
     server.start()
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
     res = client.chat.completions.create(
diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index 7b33ec531..ee9b9f466 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -249,6 +249,7 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
 
 def test_n_probs():
     global server
+    server.multi_token_probs = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
@@ -274,6 +275,7 @@ def test_n_probs():
 
 def test_n_probs_stream():
     global server
+    server.multi_token_probs = True
     server.start()
     res = server.make_stream_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index d988ccf5e..5221e0829 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -73,6 +73,7 @@ class ServerProcess:
     draft_min: int | None = None
     draft_max: int | None = None
     no_webui: bool | None = None
+    multi_token_probs: bool | None = None
 
     # session variables
     process: subprocess.Popen | None = None
@@ -161,6 +162,8 @@ class ServerProcess:
             server_args.extend(["--draft-min", self.draft_min])
         if self.no_webui:
             server_args.append("--no-webui")
+        if self.multi_token_probs:
+            server_args.append("--multi-token-probs")
 
         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"bench: starting server with: {' '.join(args)}")
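Usage sketch (editor's note, not part of the change set): the new check in server.cpp rejects any request that combines n_probs > 0 with n_predict > 1 unless the server was started with --multi-token-probs, and the utils.py change wires a matching toggle into the Python test harness. The snippet below shows how that toggle could be driven; ServerPreset.tinyllama2(), res.status_code, res.body and the completion_probabilities response field are assumed from the existing test utilities and server API, not introduced by this diff.

# Minimal sketch: start a test server with the new flag enabled and request
# per-token probabilities for a multi-token completion.
from utils import ServerPreset  # existing test helper; tinyllama2() is an assumed preset

server = ServerPreset.tinyllama2()
server.multi_token_probs = True   # new ServerProcess field -> passes --multi-token-probs on start
server.start()

# With multi_token_probs left unset, the server would reject this request, because it
# sets n_probs > 0 while predicting more than one token (see the new check in server.cpp).
res = server.make_request("POST", "/completion", data={
    "prompt": "I believe the meaning of life is",
    "n_predict": 5,
    "n_probs": 10,
})
assert res.status_code == 200
# Response field name assumed from the existing /completion n_probs behavior.
assert "completion_probabilities" in res.body

server.stop()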