resolve review comments

commit d2463dc8df
parent fd4cf34b00

3 changed files with 41 additions and 38 deletions
@@ -496,8 +496,8 @@ These words will not be included in the completion, so make sure to add them to
   },
 ```
 Please note that if `post_sampling_probs` is set to `true`:
-- `logprob` will be replace with `prob`, with the value between 0.0 and 1.0
-- `top_logprobs` will be replace with `top_probs`. Each element inside contains:
+- `logprob` will be replaced with `prob`, with the value between 0.0 and 1.0
+- `top_logprobs` will be replaced with `top_probs`. Each element contains:
   - `id`: token ID
   - `token`: token in string
   - `bytes`: token in bytes

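(Not part of the diff.) A minimal client-side sketch of the option documented above. The host, port, prompt, and parameter values are assumptions for illustration; the field names (`completion_probabilities`, `prob`, `top_probs`, `id`, `token`, `bytes`) follow the README text and the test changes later in this commit.

```python
# Hypothetical usage sketch (not part of this commit); host, prompt and
# parameter values are assumptions chosen only for illustration.
import requests

res = requests.post("http://localhost:8080/completion", json={
    "prompt": "Hello",
    "n_predict": 4,
    "n_probs": 3,                  # ask for the top-3 candidates per generated token
    "post_sampling_probs": True,   # report `prob`/`top_probs` instead of `logprob`/`top_logprobs`
})

for tok in res.json()["completion_probabilities"]:
    # each generated token carries a plain probability in (0.0, 1.0]
    print(tok["id"], repr(tok["token"]), tok["prob"])
    for cand in tok["top_probs"]:
        # each candidate exposes `id`, `token`, `bytes` and `prob`
        print("  candidate:", cand["id"], repr(cand["token"]), cand["bytes"], cand["prob"])
```
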
@@ -443,7 +443,7 @@ struct completion_token_output {
     std::string text_to_send;
     struct token_prob {
         llama_token tok;
-        std::string tok_str;
+        std::string txt;
         float prob;
     };
     std::vector<token_prob> probs;

@@ -451,12 +451,12 @@ struct completion_token_output {
     json to_json(bool post_sampling_probs) const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
-            std::string tok_str(p.tok_str);
-            tok_str.resize(validate_utf8(tok_str));
+            std::string txt(p.txt);
+            txt.resize(validate_utf8(txt));
             probs_for_token.push_back(json {
                 {"id", p.tok},
-                {"token", tok_str},
-                {"bytes", str_to_bytes(p.tok_str)},
+                {"token", txt},
+                {"bytes", str_to_bytes(p.txt)},
                 {
                     post_sampling_probs ? "prob" : "logprob",
                     post_sampling_probs ? p.prob : logarithm(p.prob)

@@ -468,20 +468,20 @@ struct completion_token_output {

     static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
         json out = json::array();
-        for (const auto & it : probs) {
-            std::string tok_str(it.text_to_send);
-            tok_str.resize(validate_utf8(tok_str));
+        for (const auto & p : probs) {
+            std::string txt(p.text_to_send);
+            txt.resize(validate_utf8(txt));
             out.push_back(json {
-                {"id", it.tok},
-                {"token", tok_str},
-                {"bytes", str_to_bytes(it.text_to_send)},
+                {"id", p.tok},
+                {"token", txt},
+                {"bytes", str_to_bytes(p.text_to_send)},
                 {
                     post_sampling_probs ? "top_probs" : "top_logprobs",
-                    it.to_json(post_sampling_probs)
+                    p.to_json(post_sampling_probs)
                 },
                 {
                     post_sampling_probs ? "prob" : "logprob",
-                    post_sampling_probs ? it.prob : logarithm(it.prob)
+                    post_sampling_probs ? p.prob : logarithm(p.prob)
                 },
             });
         }

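(Not part of the diff.) For readability, a sketch of the shape a single generated token serializes to via `to_json` and `probs_vector_to_json` above, written as a Python literal with invented values. With `post_sampling_probs` disabled, the `top_probs`/`prob` keys become `top_logprobs`/`logprob` and the values are log-probabilities.

```python
# Invented example values; the key layout mirrors the serializers above
# with post_sampling_probs = true.
token_entry = {
    "id": 9906,                          # p.tok
    "token": "Hello",                    # UTF-8-validated token text
    "bytes": [72, 101, 108, 108, 111],   # raw bytes of the token text
    "top_probs": [                       # one entry per candidate stored in `probs`
        {"id": 9906, "token": "Hello", "bytes": [72, 101, 108, 108, 111], "prob": 0.87},
        {"id": 9907, "token": "Hi",    "bytes": [72, 105],                "prob": 0.13},
    ],
    "prob": 0.87,                        # probability of the sampled token itself
}
```
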
@@ -1958,28 +1958,7 @@ struct server_context {
         size_t n_probs = slot.params.sampling.n_probs;
         int n_vocab = llama_n_vocab(llama_get_model(ctx));
         if (post_sampling) {
-            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
-
-            bool found_sampled_tok = false;
-            result.probs.reserve(n_probs);
-            for (int i = 0; i < n_vocab; i++) {
-                // set probability for sampled token
-                if (cur[i].id == result.tok) {
-                    found_sampled_tok = true;
-                    result.prob = cur[i].p;
-                }
-                // set probability for top n_probs tokens
-                result.probs.push_back({
-                    cur[i].id,
-                    common_detokenize(ctx, {cur[i].id}, special),
-                    cur[i].p
-                });
-                // break if we have all the necessary data
-                if (result.probs.size() == n_probs && found_sampled_tok) {
-                    break;
-                }
-            }
-        } else {
+            // TODO: optimize this with min-p optimization
             const auto * cur_p = common_sampler_get_candidates(slot.smpl);
             const size_t max_probs = cur_p->size;

@@ -2002,6 +1981,28 @@ struct server_context {
                     break;
                 }
             }
+        } else {
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+
+            bool found_sampled_tok = false;
+            result.probs.reserve(n_probs);
+            for (int i = 0; i < n_vocab; i++) {
+                // set probability for sampled token
+                if (cur[i].id == result.tok) {
+                    found_sampled_tok = true;
+                    result.prob = cur[i].p;
+                }
+                // set probability for top n_probs tokens
+                result.probs.push_back({
+                    cur[i].id,
+                    common_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+                // break if we have all the necessary data
+                if (result.probs.size() == n_probs && found_sampled_tok) {
+                    break;
+                }
+            }
         }
     }

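(Not part of the diff.) The two hunks above swap which branch handles `post_sampling`: the post-sampling path now reads the sampler's candidate list via `common_sampler_get_candidates`, while the default path keeps the full-vocabulary `get_token_probabilities` loop. A rough Python sketch of the distinction, with invented logits and a hypothetical top-k=2 sampler, not the server's actual implementation:

```python
# Rough illustration only: invented logits, hypothetical top-k=2 sampler.
import math

logits = {"Hello": 2.0, "Hi": 1.0, "Hey": 0.1}

# default path: softmax over the raw logits, reported as log-probabilities
z = sum(math.exp(v) for v in logits.values())
logprobs = {t: math.log(math.exp(v) / z) for t, v in logits.items()}
print(logprobs)    # roughly {'Hello': -0.42, 'Hi': -1.42, 'Hey': -2.32}

# post-sampling path: suppose a top-k=2 sampler kept only the two best
# candidates and renormalized them; those values are reported directly as `prob`
kept = dict(sorted(logits.items(), key=lambda kv: kv[1], reverse=True)[:2])
z2 = sum(math.exp(v) for v in kept.values())
post_probs = {t: math.exp(v) / z2 for t, v in kept.items()}
print(post_probs)  # roughly {'Hello': 0.73, 'Hi': 0.27}
```
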
@@ -325,7 +325,7 @@ def test_n_probs_post_sampling():
     for tok in res.body["completion_probabilities"]:
         assert "id" in tok and tok["id"] > 0
         assert "token" in tok and type(tok["token"]) == str
-        assert "prob" in tok and 0.0 <= tok["prob"] <= 1.0
+        assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
         assert "bytes" in tok and type(tok["bytes"]) == list
         assert len(tok["top_probs"]) == 10
         for prob in tok["top_probs"]:

@@ -333,3 +333,5 @@ def test_n_probs_post_sampling():
             assert "token" in prob and type(prob["token"]) == str
             assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
             assert "bytes" in prob and type(prob["bytes"]) == list
+        # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
+        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])