return pre-sampling p
This commit is contained in:
parent
01afafef93
commit
cc90cdbc33
2 changed files with 43 additions and 19 deletions
|
@ -1886,25 +1886,17 @@ struct server_context {
|
||||||
return slot.has_next_token; // continue
|
return slot.has_next_token; // continue
|
||||||
}
|
}
|
||||||
|
|
||||||
void populate_token_probs(const server_slot & slot, completion_token_output & result) {
|
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) {
|
||||||
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
|
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
|
||||||
const size_t max_probs = cur_p->size;
|
int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
|
|
||||||
// set prob for the sampled token
|
// only take at most n_probs tokens
|
||||||
for (size_t i = 0; i < max_probs; ++i) {
|
const int n_probs = slot.params.sampling.n_probs;
|
||||||
if (result.tok == cur_p->data[i].id) {
|
for (int i = 0; i < std::min(n_probs, n_vocab); i++) {
|
||||||
result.prob = cur_p->data[i].p;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// set probs for the top n tokens
|
|
||||||
for (size_t i = 0; i < std::min(max_probs, (size_t) slot.params.sampling.n_probs); ++i) {
|
|
||||||
auto tok_id = cur_p->data[i].id;
|
|
||||||
result.probs.push_back({
|
result.probs.push_back({
|
||||||
tok_id,
|
cur[i].id,
|
||||||
tokens_to_output_formatted_string(ctx, tok_id),
|
common_detokenize(ctx, {cur[i].id}, special),
|
||||||
cur_p->data[i].p,
|
cur[i].p
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2758,7 +2750,9 @@ struct server_context {
|
||||||
continue; // continue loop of slots
|
continue; // continue loop of slots
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
|
const int tok_idx = slot.i_batch - i;
|
||||||
|
|
||||||
|
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
|
||||||
|
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
|
|
||||||
|
@ -2782,7 +2776,7 @@ struct server_context {
|
||||||
result.prob = 1.0f; // set later
|
result.prob = 1.0f; // set later
|
||||||
|
|
||||||
if (slot.params.sampling.n_probs > 0) {
|
if (slot.params.sampling.n_probs > 0) {
|
||||||
populate_token_probs(slot, result);
|
populate_token_probs(slot, result, params_base.special, tok_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!process_token(result, slot)) {
|
if (!process_token(result, slot)) {
|
||||||
|
|
|
@ -694,3 +694,33 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
|
||||||
static std::string safe_json_to_str(json data) {
|
static std::string safe_json_to_str(json data) {
|
||||||
return data.dump(-1, ' ', false, json::error_handler_t::replace);
|
return data.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
|
||||||
|
std::vector<llama_token_data> cur;
|
||||||
|
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||||
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
|
|
||||||
|
cur.resize(n_vocab);
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
|
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply softmax
|
||||||
|
float max_l = cur[0].logit;
|
||||||
|
float cum_sum = 0.0f;
|
||||||
|
for (size_t i = 0; i < cur.size(); ++i) {
|
||||||
|
float p = expf(cur[i].logit - max_l);
|
||||||
|
cur[i].p = p;
|
||||||
|
cum_sum += p;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < cur.size(); ++i) {
|
||||||
|
cur[i].p /= cum_sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// sort tokens by probability
|
||||||
|
std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
|
||||||
|
return a.p > b.p;
|
||||||
|
});
|
||||||
|
|
||||||
|
return cur;
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue