Merge pull request #7 from WangHaoranRobin/robin_fork_master
server: handle probs output when temp=0; handle final response probs output
commit 77edee7d9b
1 changed file with 11 additions and 9 deletions
@@ -362,6 +362,9 @@ struct llama_server_context {
         if (temp <= 0) {
             // Greedy sampling
             result.tok = llama_sample_token_greedy(ctx, &candidates_p);
+            if (n_probs > 0) {
+                llama_sample_softmax(ctx, &candidates_p);
+            }
         } else {
             if (mirostat == 1) {
                 static float mirostat_mu = 2.0f * mirostat_tau;
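Note on the hunk above: llama_sample_token_greedy only takes the argmax of the raw logits, so at temp <= 0 the probability fields of candidates_p are never filled in. Calling llama_sample_softmax when the client asked for probabilities (n_probs > 0) normalizes (and, in llama.cpp, sorts) the candidates so the top-n entries can be reported. A standalone sketch of the idea, not llama.cpp's internals:

    // Greedy decoding needs only the argmax of the raw logits;
    // probabilities exist only after a softmax normalizes them.
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> logits = {2.0f, 1.0f, 0.5f};

        // Greedy pick: argmax over raw logits, no probabilities required.
        size_t best = (size_t)(std::max_element(logits.begin(), logits.end()) - logits.begin());

        // To report probabilities (the n_probs > 0 case), apply a softmax.
        float max_logit = logits[best];
        std::vector<float> p(logits.size());
        float sum = 0.0f;
        for (size_t i = 0; i < logits.size(); ++i) {
            p[i] = std::exp(logits[i] - max_logit);  // subtract max for numerical stability
            sum += p[i];
        }
        for (float & x : p) { x /= sum; }

        std::printf("greedy token = %zu, p = %.3f\n", best, p[best]);
        return 0;
    }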
@@ -374,10 +377,11 @@ struct llama_server_context {
                 result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
             } else {
                 // Temperature sampling
-                llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-                llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
-                llama_sample_typical(ctx, &candidates_p, typical_p, 1);
-                llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+                size_t min_keep = std::max(1, n_probs);
+                llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
+                llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
+                llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
+                llama_sample_top_k(ctx, &candidates_p, top_k, min_keep);
                 llama_sample_temperature(ctx, &candidates_p, temp);
                 result.tok = llama_sample_token(ctx, &candidates_p);
             }
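Note on the hunk above: each truncation sampler (tail-free, typical, top-p, top-k) may shrink the candidate list, and with the previous hard-coded min_keep of 1 they could leave a single survivor, making it impossible to report n_probs alternatives. Passing min_keep = std::max(1, n_probs) guarantees enough candidates survive, and top-k is applied last with the same floor. A standalone sketch of the min_keep idea, not llama.cpp's internals:

    // A truncation sampler drops low-probability candidates, but never
    // below min_keep entries, so n_probs alternatives stay reportable.
    #include <cstdio>
    #include <vector>

    // Truncate a descending-sorted probability list to the top-p nucleus,
    // keeping at least min_keep entries.
    static size_t top_p_keep(const std::vector<float> & probs_sorted,
                             float top_p, size_t min_keep) {
        float cum = 0.0f;
        size_t keep = probs_sorted.size();
        for (size_t i = 0; i < probs_sorted.size(); ++i) {
            cum += probs_sorted[i];
            if (cum >= top_p && i + 1 >= min_keep) {
                keep = i + 1;
                break;
            }
        }
        return keep;
    }

    int main() {
        std::vector<float> probs = {0.90f, 0.05f, 0.03f, 0.02f};  // sorted desc
        std::printf("min_keep=1: keep %zu\n", top_p_keep(probs, 0.9f, 1));  // keeps 1
        std::printf("min_keep=3: keep %zu\n", top_p_keep(probs, 0.9f, 3));  // keeps 3
        return 0;
    }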
@@ -752,8 +756,7 @@ static json format_final_response(llama_server_context & llama, const std::strin
     };
 
     if (llama.params.n_probs > 0) {
-        json completion_probabilities_json = probs_vector_to_json(llama.ctx, probs);
-        res["completion_probabilities"] = completion_probabilities_json;
+        res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
 
     return res;
@@ -766,8 +769,7 @@ static json format_partial_response(llama_server_context & llama, const std::str
     };
 
    if (llama.params.n_probs > 0) {
-        json completion_probabilities_json = probs_vector_to_json(llama.ctx, probs);
-        res["completion_probabilities"] = completion_probabilities_json;
+        res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
 
     return res;
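Note on the two hunks above: both formatters now assign the result of probs_vector_to_json directly instead of going through a named temporary; behavior is unchanged. A minimal sketch of the resulting response shape, using nlohmann::json (the json type the server already uses); the field names inside completion_probabilities are illustrative, the real layout is whatever probs_vector_to_json emits:

    #include <nlohmann/json.hpp>
    #include <cstdio>

    using json = nlohmann::json;

    int main() {
        json res;
        res["content"] = "Hello";

        json tok;
        tok["content"] = "Hello";  // illustrative field names
        tok["probs"]   = json::array({ json{{"tok_str", "Hello"}, {"prob", 0.92}} });

        // The simplification in the diff: assign directly, no named temporary.
        res["completion_probabilities"] = json::array({tok});

        std::printf("%s\n", res.dump(2).c_str());
        return 0;
    }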
@@ -969,7 +971,7 @@ int main(int argc, char ** argv) {
                 const json data = llama.has_next_token
                                       ? format_partial_response(llama, to_send, probs_output)
                                       // Generation is done, send extra information.
-                                      : format_final_response(llama, to_send, probs_output);
+                                      : format_final_response(llama, to_send, llama.generated_token_probs);
 
                 const std::string str =
                     "data: " +
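Note on the hunk above: the final chunk previously reused probs_output, which only holds the probabilities of the tokens in the last streamed piece. Passing llama.generated_token_probs instead lets format_final_response report probabilities for the entire completion. A standalone sketch of the accumulate-vs-chunk distinction (the vector names mirror the diff; the struct is an illustrative stand-in):

    #include <cstdio>
    #include <string>
    #include <vector>

    struct token_prob_output {  // illustrative stand-in
        std::string tok_str;
        float       prob;
    };

    int main() {
        std::vector<token_prob_output> generated_token_probs;  // whole run
        std::vector<token_prob_output> probs_output;           // current chunk

        for (int step = 0; step < 3; ++step) {
            token_prob_output t = {"tok" + std::to_string(step), 0.9f};
            probs_output = {t};                  // reset for each streamed chunk
            generated_token_probs.push_back(t);  // grows over the whole run
        }

        std::printf("final response reports %zu tokens, last chunk %zu\n",
                    generated_token_probs.size(), probs_output.size());
        return 0;
    }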