server: handle probs output when temp=0; handle final response probs output

Wang Haoran(Robin) 2023-06-25 16:29:34 -07:00
parent e815b69579
commit c9e6642cf7

@@ -357,6 +357,9 @@ struct llama_server_context {
         if (temp <= 0) {
             // Greedy sampling
             result.tok = llama_sample_token_greedy(ctx, &candidates_p);
+            if (n_probs > 0) {
+                llama_sample_softmax(ctx, &candidates_p);
+            }
         } else {
             if (mirostat == 1) {
                 static float mirostat_mu = 2.0f * mirostat_tau;
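
Note on the hunk above: llama_sample_token_greedy only picks the arg-max token and leaves the candidates' probability fields unpopulated, so without the added llama_sample_softmax call the per-token probabilities requested via n_probs would be meaningless in greedy mode. After the softmax the candidate array is sorted by probability with normalized .p values. A minimal sketch of how the top candidates could then be read out, assuming the llama.h C API types and a hypothetical token_prob output struct (not part of this commit):

    #include <algorithm>
    #include <vector>
    #include "llama.h"

    struct token_prob { llama_token tok; float p; };  // hypothetical output type

    // After llama_sample_softmax the entries are sorted by probability,
    // so the first n_probs entries are the top candidates.
    static std::vector<token_prob> top_probs(const llama_token_data_array & cands, int n_probs) {
        std::vector<token_prob> out;
        const size_t n = std::min((size_t) std::max(n_probs, 0), cands.size);
        out.reserve(n);
        for (size_t i = 0; i < n; ++i) {
            out.push_back({ cands.data[i].id, cands.data[i].p });
        }
        return out;
    }
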
@@ -369,10 +372,11 @@ struct llama_server_context {
                 result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
             } else {
                 // Temperature sampling
-                llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
-                llama_sample_typical(ctx, &candidates_p, typical_p, 1);
-                llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-                llama_sample_top_k(ctx, &candidates_p, top_k, 1);
+                size_t min_keep = std::max(1, n_probs);
+                llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
+                llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
+                llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
+                llama_sample_top_k(ctx, &candidates_p, top_k, min_keep);
                 llama_sample_temperature(ctx, &candidates_p, temp);
                 result.tok = llama_sample_token(ctx, &candidates_p);
             }
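
Note on the hunk above: each truncation sampler (tail-free, typical, top-p, top-k) takes a min_keep floor on how many candidates survive. With the previous hard-coded 1, a setting such as top_k = 1 could leave a single candidate, making it impossible to report n_probs alternatives. An illustrative stand-in for that guard (the real logic lives inside llama.cpp's samplers):

    #include <algorithm>
    #include <cstddef>

    // How many candidates survive a top-k truncation with a min_keep floor.
    static size_t kept_after_top_k(size_t n_candidates, int top_k, size_t min_keep) {
        size_t k = (top_k <= 0) ? n_candidates : (size_t) top_k;
        k = std::max(k, min_keep);          // keep at least min_keep entries
        return std::min(k, n_candidates);   // but never more than exist
    }
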
@@ -747,8 +751,7 @@ static json format_final_response(llama_server_context & llama, const std::strin
     };
 
     if (llama.params.n_probs > 0) {
-        json completion_probabilities_json = probs_vector_to_json(llama.ctx, probs);
-        res["completion_probabilities"] = completion_probabilities_json;
+        res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
 
     return res;
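
Note: this hunk and the next are behavior-neutral cleanups that fold the temporary json variable into a direct assignment. For reference, a sketch of the shape completion_probabilities ends up with; the field names are assumptions about probs_vector_to_json's output and the values are made up for illustration:

    #include <nlohmann/json.hpp>
    using json = nlohmann::json;

    int main() {
        json tok;
        tok["tok_str"] = "He";   // candidate token text (assumed field name)
        tok["prob"]    = 0.95;   // its probability      (assumed field name)

        json entry;
        entry["content"] = "He";                 // the token actually emitted
        entry["probs"]   = json::array({ tok }); // its top-n_probs candidates

        json res;
        res["completion_probabilities"] = json::array({ entry });
        return 0;
    }
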
@@ -761,8 +764,7 @@ static json format_partial_response(llama_server_context & llama, const std::str
     };
 
     if (llama.params.n_probs > 0) {
-        json completion_probabilities_json = probs_vector_to_json(llama.ctx, probs);
-        res["completion_probabilities"] = completion_probabilities_json;
+        res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
 
     return res;
@@ -964,7 +966,7 @@ int main(int argc, char ** argv) {
                 const json data = llama.has_next_token
                                       ? format_partial_response(llama, to_send, probs_output)
                                       // Generation is done, send extra information.
-                                      : format_final_response(llama, to_send, probs_output);
+                                      : format_final_response(llama, to_send, llama.generated_token_probs);
 
                 const std::string str =
                     "data: " +