"top_probs" with "post_sampling_probs"

This commit is contained in:
Xuan Son Nguyen 2024-12-18 17:27:29 +01:00
parent 8734df73d9
commit fd4cf34b00
3 changed files with 12 additions and 4 deletions

View file

@ -497,7 +497,12 @@ These words will not be included in the completion, so make sure to add them to
```
Please note that if `post_sampling_probs` is set to `true`:
- `logprob` will be replace with `prob`, with the value between 0.0 and 1.0
- Returned number of probabilities may be less than `n_probs`
- `top_logprobs` will be replace with `top_probs`. Each element inside contains:
- `id`: token ID
- `token`: token in string
- `bytes`: token in bytes
- `prob`: token probability, with the value between 0.0 and 1.0
- Number of elements in `top_probs` may be less than `n_probs`
- `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
- `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request.

View file

@ -475,7 +475,10 @@ struct completion_token_output {
{"id", it.tok},
{"token", tok_str},
{"bytes", str_to_bytes(it.text_to_send)},
{"top_logprobs", it.to_json(post_sampling_probs)},
{
post_sampling_probs ? "top_probs" : "top_logprobs",
it.to_json(post_sampling_probs)
},
{
post_sampling_probs ? "prob" : "logprob",
post_sampling_probs ? it.prob : logarithm(it.prob)

View file

@ -327,8 +327,8 @@ def test_n_probs_post_sampling():
assert "token" in tok and type(tok["token"]) == str
assert "prob" in tok and 0.0 <= tok["prob"] <= 1.0
assert "bytes" in tok and type(tok["bytes"]) == list
assert len(tok["top_logprobs"]) == 10
for prob in tok["top_logprobs"]:
assert len(tok["top_probs"]) == 10
for prob in tok["top_probs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0