server : fix logprobs, make it OAI-compatible (#10783)

* server : fix logprobs, make it openai-compatible

* update docs

* add std::log

* return pre-sampling p

* sort before apply softmax

* add comment

* fix test

* set p for sampled token

* update docs

* add --multi-token-probs

* update docs

* add `post_sampling_probs` option

* update docs [no ci]

* remove --multi-token-probs

* "top_probs" with "post_sampling_probs"

* resolve review comments

* rename struct token_prob to prob_info

* correct comment placement

* fix setting prob for sampled token
This commit is contained in:
Xuan Son Nguyen 2024-12-19 15:40:08 +01:00 committed by GitHub
parent a3c33b1dce
commit 57bb2c40cd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 396 additions and 107 deletions

View file

@ -171,6 +171,36 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
return result;
}
// return the last index of character that can form a valid string
// if the last character is potentially cut in half, return the index before the cut
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
static size_t validate_utf8(const std::string& text) {
size_t len = text.size();
if (len == 0) return 0;
// Check the last few bytes to see if a multi-byte character is cut off
for (size_t i = 1; i <= 4 && i <= len; ++i) {
unsigned char c = text[len - i];
// Check for start of a multi-byte sequence from the end
if ((c & 0xE0) == 0xC0) {
// 2-byte character start: 110xxxxx
// Needs at least 2 bytes
if (i < 2) return len - i;
} else if ((c & 0xF0) == 0xE0) {
// 3-byte character start: 1110xxxx
// Needs at least 3 bytes
if (i < 3) return len - i;
} else if ((c & 0xF8) == 0xF0) {
// 4-byte character start: 11110xxx
// Needs at least 4 bytes
if (i < 4) return len - i;
}
}
// If no cut-off multi-byte character is found, return full length
return len;
}
//
// template utils
//
@ -671,3 +701,33 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
static std::string safe_json_to_str(json data) {
return data.dump(-1, ' ', false, json::error_handler_t::replace);
}
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
std::vector<llama_token_data> cur;
const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
// sort tokens by logits
std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
// apply softmax
float max_l = cur[0].logit;
float cum_sum = 0.0f;
for (size_t i = 0; i < cur.size(); ++i) {
float p = expf(cur[i].logit - max_l);
cur[i].p = p;
cum_sum += p;
}
for (size_t i = 0; i < cur.size(); ++i) {
cur[i].p /= cum_sum;
}
return cur;
}