return pre-sampling p

Xuan Son Nguyen 2024-12-12 13:44:30 +01:00
parent 01afafef93
commit cc90cdbc33
2 changed files with 43 additions and 19 deletions

@@ -1886,25 +1886,17 @@ struct server_context {
         return slot.has_next_token; // continue
     }
 
-    void populate_token_probs(const server_slot & slot, completion_token_output & result) {
-        const auto * cur_p = common_sampler_get_candidates(slot.smpl);
-        const size_t max_probs = cur_p->size;
+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
+        std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+        int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
-        // set prob for the sampled token
-        for (size_t i = 0; i < max_probs; ++i) {
-            if (result.tok == cur_p->data[i].id) {
-                result.prob = cur_p->data[i].p;
-                break;
-            }
-        }
-
-        // set probs for the top n tokens
-        for (size_t i = 0; i < std::min(max_probs, (size_t) slot.params.sampling.n_probs); ++i) {
-            auto tok_id = cur_p->data[i].id;
+        // only take at most n_probs tokens
+        const int n_probs = slot.params.sampling.n_probs;
+        for (int i = 0; i < std::min(n_probs, n_vocab); i++) {
             result.probs.push_back({
-                tok_id,
-                tokens_to_output_formatted_string(ctx, tok_id),
-                cur_p->data[i].p,
+                cur[i].id,
+                common_detokenize(ctx, {cur[i].id}, special),
+                cur[i].p
             });
         }
     }
@@ -2758,7 +2750,9 @@ struct server_context {
                 continue; // continue loop of slots
             }
 
-            llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+            const int tok_idx = slot.i_batch - i;
+
+            llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
 
             slot.i_batch = -1;
 
@@ -2782,7 +2776,7 @@ struct server_context {
             result.prob = 1.0f; // set later
 
             if (slot.params.sampling.n_probs > 0) {
-                populate_token_probs(slot, result);
+                populate_token_probs(slot, result, params_base.special, tok_idx);
             }
 
             if (!process_token(result, slot)) {
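Note: each entry pushed into result.probs by the reworked populate_token_probs pairs a token id with its detokenized text and its pre-sampling probability. A rough sketch of the record shape implied by the brace-initializer, assuming the usual llama.h types (the actual nested struct inside completion_token_output in server.cpp may use different names and fields):

    #include <string>
    #include <vector>

    #include "llama.h"

    // Hypothetical stand-in for one candidate entry in result.probs.
    struct token_prob_sketch {
        llama_token tok;     // candidate token id
        std::string tok_str; // detokenized text of that token
        float       prob;    // pre-sampling probability from get_token_probabilities()
    };

    // The per-token result would then carry the sampled token plus its top candidates.
    struct completion_token_output_sketch {
        llama_token                    tok;   // sampled token
        float                          prob;  // probability of the sampled token
        std::vector<token_prob_sketch> probs; // up to n_probs candidates, highest probability first
    };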

@@ -694,3 +694,33 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
 static std::string safe_json_to_str(json data) {
     return data.dump(-1, ' ', false, json::error_handler_t::replace);
 }
+
+static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
+    std::vector<llama_token_data> cur;
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+    cur.resize(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    // apply softmax
+    float max_l = cur[0].logit;
+    float cum_sum = 0.0f;
+    for (size_t i = 0; i < cur.size(); ++i) {
+        float p = expf(cur[i].logit - max_l);
+        cur[i].p = p;
+        cum_sum += p;
+    }
+    for (size_t i = 0; i < cur.size(); ++i) {
+        cur[i].p /= cum_sum;
+    }
+
+    // sort tokens by probability
+    std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
+        return a.p > b.p;
+    });
+
+    return cur;
+}
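Note on the softmax step in get_token_probabilities: subtracting a constant from every logit before exponentiating leaves the normalized probabilities unchanged, and subtracting the maximum logit is the usual guard against overflow in expf(). In the hunk above, max_l is taken from cur[0] before any sorting, i.e. it is the logit of token id 0 rather than the true maximum; the result is still mathematically the same, just without that guard. A standalone sketch of the stabilized computation, independent of the llama.cpp types (hypothetical helper, not part of this patch):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Convert raw logits to probabilities, subtracting the max logit so
    // std::exp() never sees a large positive argument.
    static std::vector<float> softmax(const std::vector<float> & logits) {
        const float max_l = *std::max_element(logits.begin(), logits.end());
        std::vector<float> probs(logits.size());
        float cum_sum = 0.0f;
        for (size_t i = 0; i < logits.size(); ++i) {
            probs[i] = std::exp(logits[i] - max_l);
            cum_sum += probs[i];
        }
        for (float & p : probs) {
            p /= cum_sum; // normalize so the probabilities sum to 1
        }
        return probs;
    }

    int main() {
        for (float p : softmax({2.0f, 1.0f, 0.5f})) {
            printf("%.4f\n", p); // approx. 0.6285, 0.2312, 0.1402
        }
        return 0;
    }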