return pre-sampling p
This commit is contained in:
parent
01afafef93
commit
cc90cdbc33
2 changed files with 43 additions and 19 deletions
|
@ -1886,25 +1886,17 @@ struct server_context {
|
|||
return slot.has_next_token; // continue
|
||||
}
|
||||
|
||||
void populate_token_probs(const server_slot & slot, completion_token_output & result) {
|
||||
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
|
||||
const size_t max_probs = cur_p->size;
|
||||
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
|
||||
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
|
||||
int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
|
||||
// set prob for the sampled token
|
||||
for (size_t i = 0; i < max_probs; ++i) {
|
||||
if (result.tok == cur_p->data[i].id) {
|
||||
result.prob = cur_p->data[i].p;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// set probs for the top n tokens
|
||||
for (size_t i = 0; i < std::min(max_probs, (size_t) slot.params.sampling.n_probs); ++i) {
|
||||
auto tok_id = cur_p->data[i].id;
|
||||
// only take at most n_probs tokens
|
||||
const int n_probs = slot.params.sampling.n_probs;
|
||||
for (int i = 0; i < std::min(n_probs, n_vocab); i++) {
|
||||
result.probs.push_back({
|
||||
tok_id,
|
||||
tokens_to_output_formatted_string(ctx, tok_id),
|
||||
cur_p->data[i].p,
|
||||
cur[i].id,
|
||||
common_detokenize(ctx, {cur[i].id}, special),
|
||||
cur[i].p
|
||||
});
|
||||
}
|
||||
}
|
||||
|
@ -2758,7 +2750,9 @@ struct server_context {
|
|||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
|
||||
const int tok_idx = slot.i_batch - i;
|
||||
|
||||
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
|
||||
|
||||
slot.i_batch = -1;
|
||||
|
||||
|
@ -2782,7 +2776,7 @@ struct server_context {
|
|||
result.prob = 1.0f; // set later
|
||||
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
populate_token_probs(slot, result);
|
||||
populate_token_probs(slot, result, params_base.special, tok_idx);
|
||||
}
|
||||
|
||||
if (!process_token(result, slot)) {
|
||||
|
|
|
@ -694,3 +694,33 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
|
|||
static std::string safe_json_to_str(json data) {
|
||||
return data.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
}
|
||||
|
||||
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
|
||||
std::vector<llama_token_data> cur;
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
|
||||
cur.resize(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
}
|
||||
|
||||
// apply softmax
|
||||
float max_l = cur[0].logit;
|
||||
float cum_sum = 0.0f;
|
||||
for (size_t i = 0; i < cur.size(); ++i) {
|
||||
float p = expf(cur[i].logit - max_l);
|
||||
cur[i].p = p;
|
||||
cum_sum += p;
|
||||
}
|
||||
for (size_t i = 0; i < cur.size(); ++i) {
|
||||
cur[i].p /= cum_sum;
|
||||
}
|
||||
|
||||
// sort tokens by probability
|
||||
std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.p > b.p;
|
||||
});
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue