add post_sampling_probs option

parent c0cca53d85
commit ecadd37c63

4 changed files with 150 additions and 78 deletions
````diff
@@ -449,52 +449,56 @@ These words will not be included in the completion, so make sure to add them to
 `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
 
+`post_sampling_probs`: Returns the probabilities of the top `n_probs` tokens after applying the sampling chain.
+
 **Response format**
 
 - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
 
 - `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has a nested array `top_logprobs`. It contains at **maximum** `n_probs` elements:
 
 ```json
 {
   "content": "<the generated completion text>",
   "tokens": [ generated token ids if requested ],
   ...
   "probs": [
     {
       "id": <token id>,
       "logprob": float,
       "token": "<most likely token>",
       "bytes": [int, int, ...],
       "top_logprobs": [
         {
           "id": <token id>,
           "logprob": float,
           "token": "<token text>",
           "bytes": [int, int, ...],
         },
         {
           "id": <token id>,
           "logprob": float,
           "token": "<token text>",
           "bytes": [int, int, ...],
         },
         ...
       ]
     },
     {
       "id": <token id>,
       "logprob": float,
       "token": "<most likely token>",
       "bytes": [int, int, ...],
       "top_logprobs": [
         ...
       ]
     },
     ...
   ]
 },
 ```
+
+Please note that if `post_sampling_probs` is set to `true`:
+
+- `logprob` will be replaced with `prob`, with a value between 0.0 and 1.0
+- The returned number of probabilities may be less than `n_probs`
 
 - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
 - `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request.
````
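A minimal client call exercising the new option might look like the following sketch (it assumes a llama-server instance listening on `localhost:8080` and the third-party `requests` package; the request fields mirror the README above and the new test below):

```python
import requests

# Request post-sampling probabilities for the top 10 candidates of each token.
res = requests.post("http://localhost:8080/completion", json={
    "prompt": "I believe the meaning of life is",
    "n_predict": 5,
    "n_probs": 10,
    "post_sampling_probs": True,
})
res.raise_for_status()

for tok in res.json()["completion_probabilities"]:
    # With post_sampling_probs=true each entry carries "prob" (0.0..1.0)
    # instead of "logprob".
    print(tok["token"], tok["prob"])
```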
```diff
@@ -93,6 +93,7 @@ struct slot_params {
 
     std::vector<std::string> antiprompt;
     bool timings_per_token = false;
+    bool post_sampling_probs = false;
     bool ignore_eos = false;
 
     struct common_params_sampling sampling;
@@ -151,6 +152,7 @@ struct slot_params {
             {"speculative.n_min", speculative.n_min},
             {"speculative.p_min", speculative.p_min},
             {"timings_per_token", timings_per_token},
+            {"post_sampling_probs", post_sampling_probs},
         };
     }
 };
@@ -231,6 +233,7 @@ struct server_task {
         params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
         params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
         params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
+        params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
 
         params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
         params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
@@ -449,7 +452,7 @@ struct completion_token_output {
     };
     std::vector<token_prob> probs;
 
-    json to_json() const {
+    json to_json(bool post_sampling_probs) const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
             std::string tok_str(p.tok_str);
@@ -458,13 +461,16 @@ struct completion_token_output {
                 {"id", p.tok},
                 {"token", tok_str},
                 {"bytes", str_to_bytes(p.tok_str)},
-                {"logprob", logarithm(p.prob)},
+                {
+                    post_sampling_probs ? "prob" : "logprob",
+                    post_sampling_probs ? p.prob : logarithm(p.prob)
+                },
             });
         }
         return probs_for_token;
     }
 
-    static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
+    static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
         json out = json::array();
         for (const auto & it : probs) {
             std::string tok_str(it.text_to_send);
@@ -472,9 +478,12 @@ struct completion_token_output {
             out.push_back(json {
                 {"id", it.tok},
                 {"token", tok_str},
-                {"logprob", logarithm(it.prob)},
                 {"bytes", str_to_bytes(it.text_to_send)},
-                {"top_logprobs", it.to_json()},
+                {"top_logprobs", it.to_json(post_sampling_probs)},
+                {
+                    post_sampling_probs ? "prob" : "logprob",
+                    post_sampling_probs ? it.prob : logarithm(it.prob)
+                },
             });
         }
         return out;
```
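Both serializers now switch a single key/value pair per token entry. A Python sketch of the equivalent logic (illustrative only; `math.log` stands in for the server's `logarithm` helper):

```python
import math

def token_entry(tok_id: int, tok_str: str, tok_bytes: list, p: float,
                post_sampling: bool) -> dict:
    # Mirrors the C++ ternary: emit the raw probability after sampling,
    # or the log-probability otherwise.
    key, value = ("prob", p) if post_sampling else ("logprob", math.log(p))
    return {"id": tok_id, "token": tok_str, "bytes": tok_bytes, key: value}
```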
```diff
@@ -512,6 +521,7 @@ struct server_task_result_cmpl_final : server_task_result {
     std::string stopping_word;
     stop_type stop = STOP_TYPE_NONE;
 
+    bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
 
     slot_params generation_params;
@@ -557,7 +567,7 @@ struct server_task_result_cmpl_final : server_task_result {
             {"timings", timings.to_json()},
         };
         if (!stream && !probs_output.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
         }
         return res;
     }
@@ -579,7 +589,7 @@ struct server_task_result_cmpl_final : server_task_result {
 
         if (!stream && probs_output.size() > 0) {
             choice["logprobs"] = json{
-                {"content", completion_token_output::probs_vector_to_json(probs_output)},
+                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
             };
         }
 
@@ -652,6 +662,7 @@ struct server_task_result_cmpl_partial : server_task_result {
     int32_t n_decoded;
     int32_t n_prompt_tokens;
 
+    bool post_sampling_probs;
     completion_token_output prob_output;
     result_timings timings;
 
@@ -690,7 +701,7 @@ struct server_task_result_cmpl_partial : server_task_result {
             res.push_back({"timings", timings.to_json()});
         }
         if (!prob_output.probs.empty()) {
-            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output});
+            res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
         }
         return res;
     }
@@ -746,7 +757,7 @@ struct server_task_result_cmpl_partial : server_task_result {
 
         if (prob_output.probs.size() > 0) {
             choices[0]["logprobs"] = json{
-                {"content", completion_token_output::probs_vector_to_json({prob_output})},
+                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
             };
         }
 
```
```diff
@@ -1944,28 +1955,53 @@ struct server_context {
         return slot.has_next_token; // continue
     }
 
-    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
-        std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
-        int n_vocab = llama_n_vocab(llama_get_model(ctx));
+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
         size_t n_probs = slot.params.sampling.n_probs;
+        int n_vocab = llama_n_vocab(llama_get_model(ctx));
+        if (post_sampling) {
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
 
             bool found_sampled_tok = false;
             result.probs.reserve(n_probs);
             for (int i = 0; i < n_vocab; i++) {
                 // set probability for sampled token
                 if (cur[i].id == result.tok) {
                     found_sampled_tok = true;
                     result.prob = cur[i].p;
+                }
+                // set probability for top n_probs tokens
+                result.probs.push_back({
+                    cur[i].id,
+                    common_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+                // break if we have all the necessary data
+                if (result.probs.size() == n_probs && found_sampled_tok) {
+                    break;
+                }
             }
-            // set probability for top n_probs tokens
-            result.probs.push_back({
-                cur[i].id,
-                common_detokenize(ctx, {cur[i].id}, special),
-                cur[i].p
-            });
-            // break if we have all the necessary data
-            if (result.probs.size() == n_probs && found_sampled_tok) {
-                break;
+        } else {
+            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+            const size_t max_probs = cur_p->size;
+
+            bool found_sampled_tok = false;
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < max_probs; i++) {
+                // set probability for sampled token
+                if (cur_p->data[i].id == result.tok) {
+                    found_sampled_tok = true;
+                    result.prob = cur_p->data[i].p;
+                }
+                // set probability for top n_probs tokens
+                result.probs.push_back({
+                    cur_p->data[i].id,
+                    common_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
+                });
+                // break if we have all the necessary data
+                if (result.probs.size() == n_probs && found_sampled_tok) {
+                    break;
+                }
             }
         }
     }
```
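`populate_token_probs` now draws from two different distributions: one branch applies a softmax over the raw logits of the entire vocabulary (what `get_token_probabilities` computes), the other reads the candidate array that the sampling chain leaves behind (`common_sampler_get_candidates`), which can hold fewer than `n_probs` entries. A stdlib-only Python sketch of the two computations (illustrative, not the server's actual code):

```python
import math

def softmax_probs(logits: list) -> list:
    # full softmax over raw logits, one probability per vocabulary token
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def collect_top_probs(pairs: list, sampled_id: int, n_probs: int):
    """pairs: (token_id, prob), sorted by descending probability, taken either
    from the full softmax or from the sampler's surviving candidates (the
    latter may contain fewer than n_probs entries)."""
    sampled_p = None
    top = []
    for tok_id, p in pairs:
        if tok_id == sampled_id:
            sampled_p = p          # probability of the token actually sampled
        if len(top) < n_probs:
            top.append((tok_id, p))
        if sampled_p is not None and len(top) == n_probs:
            break                  # same early-exit idea as the C++ loop
    return sampled_p, top
```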
```diff
@@ -1997,8 +2033,9 @@ struct server_context {
         res->content = tkn.text_to_send;
         res->tokens = { tkn.tok };
 
         res->n_decoded = slot.n_decoded;
         res->n_prompt_tokens = slot.n_prompt_tokens;
+        res->post_sampling_probs = slot.params.post_sampling_probs;
 
         res->verbose = slot.params.verbose;
         res->oaicompat = slot.params.oaicompat;
@@ -2030,13 +2067,14 @@ struct server_context {
         res->timings = slot.get_timings();
         res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
 
         res->truncated = slot.truncated;
         res->n_decoded = slot.n_decoded;
         res->n_prompt_tokens = slot.n_prompt_tokens;
         res->n_tokens_cached = slot.n_past;
         res->has_new_line = slot.has_new_line;
         res->stopping_word = slot.stopping_word;
         res->stop = slot.stop;
+        res->post_sampling_probs = slot.params.post_sampling_probs;
 
         res->verbose = slot.params.verbose;
         res->stream = slot.params.stream;
@@ -2859,7 +2897,7 @@ struct server_context {
         result.prob = 1.0f; // set later
 
         if (slot.params.sampling.n_probs > 0) {
-            populate_token_probs(slot, result, params_base.special, tok_idx);
+            populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
         }
 
         if (!process_token(result, slot)) {
```
```diff
@@ -309,3 +309,30 @@ def test_n_probs_stream():
             assert "token" in prob and type(prob["token"]) == str
             assert "logprob" in prob and prob["logprob"] <= 0.0
             assert "bytes" in prob and type(prob["bytes"]) == list
+
+
+def test_n_probs_post_sampling():
+    global server
+    server.multi_token_probs = True
+    server.start()
+    res = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "n_probs": 10,
+        "temperature": 0.0,
+        "n_predict": 5,
+        "post_sampling_probs": True,
+    })
+    assert res.status_code == 200
+    assert "completion_probabilities" in res.body
+    assert len(res.body["completion_probabilities"]) == 5
+    for tok in res.body["completion_probabilities"]:
+        assert "id" in tok and tok["id"] > 0
+        assert "token" in tok and type(tok["token"]) == str
+        assert "prob" in tok and 0.0 <= tok["prob"] <= 1.0
+        assert "bytes" in tok and type(tok["bytes"]) == list
+        assert len(tok["top_logprobs"]) == 10
+        for prob in tok["top_logprobs"]:
+            assert "id" in prob and prob["id"] > 0
+            assert "token" in prob and type(prob["token"]) == str
+            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
+            assert "bytes" in prob and type(prob["bytes"]) == list
```
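Note that the post-sampling test asserts `prob` values in `[0.0, 1.0]`, where the pre-sampling tests assert non-positive `logprob` values. A client that still wants log-probabilities from post-sampling output can convert directly; a trivial, hypothetical helper:

```python
import math

def to_logprob(prob: float) -> float:
    # prob can be 0.0 for tokens pruned by the sampling chain
    return math.log(prob) if prob > 0.0 else float("-inf")
```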
```diff
@@ -50,6 +50,8 @@ def test_embedding_multiple():
 @pytest.mark.parametrize(
     "input,is_multi_prompt",
     [
+        # do not crash on empty input
+        ("", False),
         # single prompt
         ("string", False),
         ([12, 34, 56], False),
@@ -103,6 +105,7 @@ def test_embedding_pooling_none_oai():
 
     # /v1/embeddings does not support pooling type 'none'
     assert res.status_code == 400
+    assert "error" in res.body
 
 
 def test_embedding_openai_library_single():
```