add --multi-token-probs
commit 196e237e09
parent 06bb38e75d

6 changed files with 22 additions and 0 deletions
@@ -1057,6 +1057,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.grammar = json_schema_to_grammar(json::parse(value));
         }
     ).set_sparam());
+    add_opt(common_arg(
+        {"-mtp", "--multi-token-probs"},
+        string_format(
+            "allow getting probabilities for multiple tokens. note: this will slow down the generation speed (default: %s)",
+            params.sampling.multi_token_probs ? "enabled" : "disabled"
+        ),
+        [](common_params & params) {
+            params.sampling.multi_token_probs = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MULTI_TOKEN_PROBS"));
     add_opt(common_arg(
         {"--pooling"}, "{none,mean,cls,last,rank}",
         "pooling type for embeddings, use model default if unspecified",
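The option above is registered only for the server example and can also be toggled through the LLAMA_ARG_MULTI_TOKEN_PROBS environment variable. A minimal launch sketch in Python; the binary name, model path, port, and the assumption that any non-empty value enables the boolean env var are illustrative, not part of this commit:

```python
import os
import subprocess

# Start llama-server with multi-token probabilities enabled via the env var,
# equivalent to passing -mtp / --multi-token-probs on the command line.
env = dict(os.environ, LLAMA_ARG_MULTI_TOKEN_PROBS="1")  # assumption: any non-empty value enables the flag

server = subprocess.Popen(
    ["./llama-server", "-m", "model.gguf", "--port", "8080"],  # hypothetical binary, model and port
    env=env,
)
```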
@@ -134,6 +134,7 @@ struct common_params_sampling {
     bool ignore_eos = false;
     bool no_perf = false; // disable performance metrics
     bool timing_per_token = false;
+    bool multi_token_probs = false; // output probabilities for multiple tokens (when n_probs > 0)

     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
@@ -239,6 +239,10 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);

+        if (!params_base.sampling.multi_token_probs && params.n_predict > 1 && params.sampling.n_probs > 0) {
+            throw std::runtime_error("For performance reason, n_probs with n_predict > 1 is not allowed. To enable this, start the server with --multi-token-probs");
+        }
+
         if (params.sampling.dry_base < 1.0f) {
             params.sampling.dry_base = defaults.sampling.dry_base;
         }
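The guard above rejects requests that ask for per-token probabilities across more than one predicted token unless the server was started with --multi-token-probs. A rough client-side sketch of the combination being guarded; the host, port, use of the requests library, and the exact response shape are assumptions, while n_predict and n_probs are the request fields named in the error message:

```python
import requests

# Assumes a llama-server instance listening on localhost:8080.
resp = requests.post("http://localhost:8080/completion", json={
    "prompt": "I believe the meaning of life is",
    "n_predict": 8,   # more than one generated token...
    "n_probs": 10,    # ...combined with per-token probabilities
})

# Without --multi-token-probs the server refuses this combination;
# with the flag it returns the per-token probabilities alongside the text.
print(resp.status_code, resp.json())
```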
@@ -166,6 +166,7 @@ def test_chat_completion_with_timings_per_token():

 def test_logprobs():
     global server
+    server.multi_token_probs = True
     server.start()
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
     res = client.chat.completions.create(
@@ -193,6 +194,7 @@ def test_logprobs():

 def test_logprobs_stream():
     global server
+    server.multi_token_probs = True
     server.start()
     client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
     res = client.chat.completions.create(
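The chat-completion tests now opt in with server.multi_token_probs = True before requesting log probabilities through the OpenAI client. The same behaviour can be exercised outside the test harness; a sketch assuming a server started with --multi-token-probs on localhost:8080 and that the model name is not validated by the server:

```python
from openai import OpenAI

client = OpenAI(api_key="dummy", base_url="http://localhost:8080")
res = client.chat.completions.create(
    model="local",      # assumption: llama-server accepts an arbitrary model name here
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=8,
    logprobs=True,      # request per-token log probabilities
    top_logprobs=5,     # number of alternatives per position
)

# Each returned token carries its logprob plus the top alternatives.
print(res.choices[0].logprobs.content[0])
```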
@@ -249,6 +249,7 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):

 def test_n_probs():
     global server
+    server.multi_token_probs = True
     server.start()
     res = server.make_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
@@ -274,6 +275,7 @@ def test_n_probs():

 def test_n_probs_stream():
     global server
+    server.multi_token_probs = True
     server.start()
     res = server.make_stream_request("POST", "/completion", data={
         "prompt": "I believe the meaning of life is",
@@ -73,6 +73,7 @@ class ServerProcess:
     draft_min: int | None = None
     draft_max: int | None = None
     no_webui: bool | None = None
+    multi_token_probs: bool | None = None

     # session variables
     process: subprocess.Popen | None = None
@@ -161,6 +162,8 @@ class ServerProcess:
             server_args.extend(["--draft-min", self.draft_min])
         if self.no_webui:
             server_args.append("--no-webui")
+        if self.multi_token_probs:
+            server_args.append("--multi-token-probs")

         args = [str(arg) for arg in [server_path, *server_args]]
         print(f"bench: starting server with: {' '.join(args)}")