From c9e6642cf7eae4223e96f224a6c50d0712f6a2c2 Mon Sep 17 00:00:00 2001
From: "Wang Haoran(Robin)"
Date: Sun, 25 Jun 2023 16:29:34 -0700
Subject: [PATCH] server: handle probs output when temp=0; handle final response probs output

---
 examples/server/server.cpp | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 3d076060c..e349489f8 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -357,6 +357,9 @@ struct llama_server_context {
             if (temp <= 0) {
                 // Greedy sampling
                 result.tok = llama_sample_token_greedy(ctx, &candidates_p);
+                if (n_probs > 0) {
+                    llama_sample_softmax(ctx, &candidates_p);
+                }
             } else {
                 if (mirostat == 1) {
                     static float mirostat_mu = 2.0f * mirostat_tau;
@@ -369,10 +372,11 @@ struct llama_server_context {
                     result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
                 } else {
                     // Temperature sampling
-                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
-                    llama_sample_typical(ctx, &candidates_p, typical_p, 1);
-                    llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-                    llama_sample_top_k(ctx, &candidates_p, top_k, 1);
+                    size_t min_keep = std::max(1, n_probs);
+                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
+                    llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
+                    llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
+                    llama_sample_top_k(ctx, &candidates_p, top_k, min_keep);
                     llama_sample_temperature(ctx, &candidates_p, temp);
                     result.tok = llama_sample_token(ctx, &candidates_p);
                 }
@@ -747,8 +751,7 @@ static json format_final_response(llama_server_context & llama, const std::strin
     };
 
     if (llama.params.n_probs > 0) {
-        json completion_probabilities_json = probs_vector_to_json(llama.ctx, probs);
-        res["completion_probabilities"] = completion_probabilities_json;
+        res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
 
     return res;
@@ -761,8 +764,7 @@ static json format_partial_response(llama_server_context & llama, const std::str
     };
 
     if (llama.params.n_probs > 0) {
-        json completion_probabilities_json = probs_vector_to_json(llama.ctx, probs);
-        res["completion_probabilities"] = completion_probabilities_json;
+        res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
     }
 
     return res;
@@ -964,7 +966,7 @@ int main(int argc, char ** argv) {
                 const json data = llama.has_next_token
                                       ? format_partial_response(llama, to_send, probs_output)
                                       // Generation is done, send extra information.
-                                      : format_final_response(llama, to_send, probs_output);
+                                      : format_final_response(llama, to_send, llama.generated_token_probs);
 
                 const std::string str = "data: " +