diff --git a/examples/common.h b/examples/common.h index 6c2953cb2..6393c8563 100644 --- a/examples/common.h +++ b/examples/common.h @@ -31,6 +31,7 @@ struct gpt_params { int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs bool low_vram = 0; // if true, reduce VRAM usage at the cost of performance + int32_t n_probs = 0; // if greater than 1, output the probabilities of top n_probs tokens. Max 5 // sampling parameters std::unordered_map logit_bias; // logit bias for specific tokens diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c0984aadb..16f5bac0a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -26,6 +26,28 @@ struct server_params { int32_t write_timeout = 600; }; +// completion string output with probabilities +struct completion_string_output { + struct token_prob { + std::string tok_str; + float prob; + }; + + std::vector probs; + std::string tok_str; +}; + +// completion token output with probabilities +struct completion_token_output { + struct token_prob { + llama_token tok; + float prob; + }; + + std::vector probs; + llama_token tok; +}; + static size_t common_part(const std::vector & a, const std::vector & b) { size_t i; for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} @@ -107,6 +129,7 @@ struct llama_server_context { bool stream = false; bool has_next_token = false; std::string generated_text; + std::vector generated_text_probs; size_t num_tokens_predicted = 0; size_t n_past = 0; @@ -137,6 +160,7 @@ struct llama_server_context { num_tokens_predicted = 0; generated_text = ""; generated_text.reserve(params.n_ctx); + generated_text_probs.clear(); truncated = false; stopped_eos = false; stopped_word = false; @@ -216,8 +240,9 @@ struct llama_server_context { llama_set_rng_seed(ctx, params.seed); } - llama_token nextToken() { - llama_token result = -1; + completion_token_output nextToken() { + completion_token_output result; + result.tok = -1; if (embd.size() >= (size_t)params.n_ctx) { // Reset context @@ -256,7 +281,8 @@ struct llama_server_context { if (params.n_predict == 0) { has_next_token = false; - return llama_token_eos(); + result.tok = llama_token_eos(); + return result; } // out of user input, sample next token @@ -273,7 +299,7 @@ struct llama_server_context { const float mirostat_tau = params.mirostat_tau; const float mirostat_eta = params.mirostat_eta; const bool penalize_nl = params.penalize_nl; - llama_token id = 0; + const int32_t n_probs = params.n_probs; { auto * logits = llama_get_logits(ctx); @@ -307,17 +333,17 @@ struct llama_server_context { if (temp <= 0) { // Greedy sampling - id = llama_sample_token_greedy(ctx, &candidates_p); + result.tok = llama_sample_token_greedy(ctx, &candidates_p); } else { if (mirostat == 1) { static float mirostat_mu = 2.0f * mirostat_tau; const int mirostat_m = 100; llama_sample_temperature(ctx, &candidates_p, temp); - id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); + result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); } else if (mirostat == 2) { static float mirostat_mu = 2.0f * mirostat_tau; llama_sample_temperature(ctx, &candidates_p, temp); - id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); + result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); } else { // Temperature sampling llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1); @@ -325,17 +351,19 @@ struct llama_server_context { llama_sample_top_p(ctx, &candidates_p, top_p, 1); llama_sample_top_k(ctx, &candidates_p, top_k, 1); llama_sample_temperature(ctx, &candidates_p, temp); - id = llama_sample_token(ctx, &candidates_p); + result.tok = llama_sample_token(ctx, &candidates_p); } } + for (size_t i = 0; i < std::min(candidates_p.size, std::min((size_t) n_probs, size_t(5))); ++i) { + result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p}); + } last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(id); + last_n_tokens.push_back(result.tok); num_tokens_predicted++; } // add it to the context - embd.push_back(id); - result = id; + embd.push_back(result.tok); // decrement remaining sampling budget --n_remain; @@ -377,12 +405,22 @@ struct llama_server_context { return stop_pos; } - std::string doCompletion() { - const llama_token token = nextToken(); + completion_string_output doCompletion() { + const completion_token_output token_with_probs = nextToken(); + completion_string_output result; - const std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token); + const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok); + result.tok_str = token_text; generated_text += token_text; + // iterate through token_with_probs.probs, if tok is valid, convert it to string and add to result.prob + for (const auto & prob : token_with_probs.probs) { + const std::string prob_text = prob.tok == -1 ? "" : llama_token_to_str(ctx, prob.tok); + result.probs.push_back({prob_text, prob.prob}); + } + + generated_text_probs.push_back(result); + if (multibyte_pending > 0) { multibyte_pending -= token_text.size(); } else if (token_text.size() == 1) { @@ -411,8 +449,8 @@ struct llama_server_context { } LOG_VERBOSE("next token", { - { "token", token }, - { "token_text", llama_token_to_str(ctx, token) }, + { "token", token_with_probs.tok }, + { "token_text", llama_token_to_str(ctx, token_with_probs.tok) }, { "has_next_token", has_next_token }, { "n_remain", n_remain }, { "num_tokens_predicted", num_tokens_predicted }, @@ -422,7 +460,7 @@ struct llama_server_context { { "stopping_word", stopping_word }, }); - return token_text; + return result; } std::vector getEmbedding() { @@ -664,6 +702,7 @@ static json format_generation_settings(llama_server_context & llama) { { "ignore_eos", ignore_eos }, { "stream", llama.stream }, { "logit_bias", llama.params.logit_bias }, + { "n_probs", llama.params.n_probs }, }; } @@ -673,9 +712,26 @@ static json format_embedding_response(llama_server_context & llama) { }; } -static json format_final_response(llama_server_context & llama, const std::string & content) { +static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector & probs) { + + json completion_probabilities_json = json::array(); + for (const auto & prob : probs) { + json probs_for_token = json::array(); + for (const auto & p : prob.probs) { + probs_for_token.push_back(json { + { "tok_str", p.tok_str }, + { "prob", p.prob }, + }); + } + completion_probabilities_json.push_back(json { + {"content", prob.tok_str}, + {"probs", probs_for_token}, + }); + } + return json { { "content", content }, + { "completion_probabilities", completion_probabilities_json}, { "stop", true }, { "model", llama.params.model_alias }, { "tokens_predicted", llama.num_tokens_predicted }, @@ -689,11 +745,25 @@ static json format_final_response(llama_server_context & llama, const std::strin }; } -static json format_partial_response(const std::string & content) { - return json { +static json format_partial_response(const std::string & content, const completion_string_output & probs) { + json res = json { { "content", content }, { "stop", false }, }; + + // iterate through probs.probs, and add to res + json probs_json = json::array(); + for (const auto & prob : probs.probs) { + probs_json.push_back(json { + { "tok_str", prob.tok_str }, + { "prob", prob.prob }, + }); + } + if (probs.probs.size() > 0) { + res["probs"] = probs_json; + } + + return res; } static json format_tokenizer_response(const std::vector & tokens) { @@ -723,6 +793,7 @@ static void parse_options_completion(const json & body, llama_server_context & l llama.params.n_keep = body.value("n_keep", default_params.n_keep); llama.params.seed = body.value("seed", default_params.seed); llama.params.prompt = body.value("prompt", default_params.prompt); + llama.params.n_probs = body.value("n_probs", default_params.n_probs); llama.params.logit_bias.clear(); if (body.value("ignore_eos", false)) { @@ -825,7 +896,8 @@ int main(int argc, char ** argv) { size_t stop_pos = std::string::npos; while (llama.has_next_token) { - const std::string token_text = llama.doCompletion(); + const completion_string_output token_text_with_probs = llama.doCompletion(); + const std::string token_text = token_text_with_probs.tok_str; stop_pos = llama.findStoppingStrings(llama.generated_text, token_text.size(), STOP_FULL); @@ -839,7 +911,7 @@ int main(int argc, char ** argv) { llama.generated_text.end()); } - const json data = format_final_response(llama, llama.generated_text); + const json data = format_final_response(llama, llama.generated_text, llama.generated_text_probs); llama_print_timings(llama.ctx); @@ -850,7 +922,7 @@ int main(int argc, char ** argv) { size_t sent_count = 0; while (llama.has_next_token) { - const std::string token_text = llama.doCompletion(); + const completion_string_output token_text_with_probs = llama.doCompletion(); if (llama.multibyte_pending > 0) { continue; } @@ -859,14 +931,14 @@ int main(int argc, char ** argv) { const std::string str_test = llama.generated_text.substr(pos); size_t stop_pos = - llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL); + llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(), STOP_FULL); if (stop_pos != std::string::npos) { llama.generated_text.erase( llama.generated_text.begin() + pos + stop_pos, llama.generated_text.end()); pos = std::min(sent_count, llama.generated_text.size()); } else { - stop_pos = llama.findStoppingStrings(str_test, token_text.size(), + stop_pos = llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(), STOP_PARTIAL); } @@ -874,9 +946,9 @@ int main(int argc, char ** argv) { sent_count += to_send.size(); const json data = llama.has_next_token - ? format_partial_response(to_send) + ? format_partial_response(to_send, token_text_with_probs) // Generation is done, send extra information. - : format_final_response(llama, to_send); + : format_final_response(llama, to_send, {token_text_with_probs}); const std::string str = "data: " +