Merge pull request #1 from WangHaoranRobin/robin_fork_master

server: add option to output probabilities for completion
WangHaoranRobin 2023-06-21 14:28:46 -07:00 committed by GitHub
commit 8004e673f0
2 changed files with 100 additions and 27 deletions


@@ -31,6 +31,7 @@ struct gpt_params {
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
bool low_vram = 0; // if true, reduce VRAM usage at the cost of performance
int32_t n_probs = 0; // if greater than 0, output the probabilities of the top n_probs tokens. Max 5
// sampling parameters
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
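
For illustration only (not part of the diff), the new field is set like any other generation parameter; the sampling loop added further down caps the reported candidates at 5, so larger values are effectively clamped:

    gpt_params params;       // sketch, assuming the struct shown above
    params.n_probs = 3;      // report the top-3 candidate probabilities for each sampled token
    // params.n_probs = 10;  // still yields at most 5 entries per token because of the cap below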


@@ -26,6 +26,28 @@ struct server_params {
int32_t write_timeout = 600;
};
// completion string output with probabilities
struct completion_string_output {
struct token_prob {
std::string tok_str;
float prob;
};
std::vector<token_prob> probs;
std::string tok_str;
};
// completion token output with probabilities
struct completion_token_output {
struct token_prob {
llama_token tok;
float prob;
};
std::vector<token_prob> probs;
llama_token tok;
};
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
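
For orientation, a minimal sketch (not part of the diff) of how the two new structs relate: completion_token_output holds raw token ids, and its string counterpart is produced by converting each id with llama_token_to_str. The real conversion lives in doCompletion() further down; this only spells out the mapping.

    // sketch: convert sampled token ids and candidate probabilities to strings
    static completion_string_output to_string_output(llama_context * ctx, const completion_token_output & in) {
        completion_string_output out;
        out.tok_str = in.tok == -1 ? "" : llama_token_to_str(ctx, in.tok);
        for (const auto & p : in.probs) {
            out.probs.push_back({ p.tok == -1 ? "" : llama_token_to_str(ctx, p.tok), p.prob });
        }
        return out;
    }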
@@ -107,6 +129,7 @@ struct llama_server_context {
bool stream = false;
bool has_next_token = false;
std::string generated_text;
std::vector<completion_string_output> generated_text_probs;
size_t num_tokens_predicted = 0;
size_t n_past = 0;
@@ -137,6 +160,7 @@ struct llama_server_context {
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(params.n_ctx);
generated_text_probs.clear();
truncated = false;
stopped_eos = false;
stopped_word = false;
@@ -216,8 +240,9 @@ struct llama_server_context {
llama_set_rng_seed(ctx, params.seed);
}
llama_token nextToken() {
llama_token result = -1;
completion_token_output nextToken() {
completion_token_output result;
result.tok = -1;
if (embd.size() >= (size_t)params.n_ctx) {
// Reset context
@@ -256,7 +281,8 @@ struct llama_server_context {
if (params.n_predict == 0) {
has_next_token = false;
return llama_token_eos();
result.tok = llama_token_eos();
return result;
}
// out of user input, sample next token
@@ -273,7 +299,7 @@ struct llama_server_context {
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
const int32_t n_probs = params.n_probs;
{
auto * logits = llama_get_logits(ctx);
@@ -307,17 +333,17 @@ struct llama_server_context {
if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
result.tok = llama_sample_token_greedy(ctx, &candidates_p);
} else {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
@@ -325,17 +351,19 @@ struct llama_server_context {
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
result.tok = llama_sample_token(ctx, &candidates_p);
}
}
for (size_t i = 0; i < std::min(candidates_p.size, std::min((size_t) n_probs, size_t(5))); ++i) {
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
last_n_tokens.push_back(result.tok);
num_tokens_predicted++;
}
// add it to the context
embd.push_back(id);
result = id;
embd.push_back(result.tok);
// decrement remaining sampling budget
--n_remain;
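
The number of candidates recorded per sampled token is bounded by three quantities: how many candidates the sampler produced, the requested n_probs, and a hard cap of 5. A hedged worked example (the helper name is invented):

    // sketch: how many probabilities get attached to one sampled token
    static size_t n_reported(size_t n_candidates, int32_t n_probs) {
        return std::min(n_candidates, std::min((size_t) n_probs, size_t(5)));
    }
    // n_reported(32000, 3)  == 3   typical vocab size, n_probs = 3
    // n_reported(32000, 10) == 5   hard cap of 5
    // n_reported(32000, 0)  == 0   probability output disabled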
@@ -377,12 +405,22 @@ struct llama_server_context {
return stop_pos;
}
std::string doCompletion() {
const llama_token token = nextToken();
completion_string_output doCompletion() {
const completion_token_output token_with_probs = nextToken();
completion_string_output result;
const std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token);
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
result.tok_str = token_text;
generated_text += token_text;
// iterate through token_with_probs.probs; if tok is valid, convert it to a string and add it to result.probs
for (const auto & prob : token_with_probs.probs) {
const std::string prob_text = prob.tok == -1 ? "" : llama_token_to_str(ctx, prob.tok);
result.probs.push_back({prob_text, prob.prob});
}
generated_text_probs.push_back(result);
if (multibyte_pending > 0) {
multibyte_pending -= token_text.size();
} else if (token_text.size() == 1) {
@@ -411,8 +449,8 @@ struct llama_server_context {
}
LOG_VERBOSE("next token", {
{ "token", token },
{ "token_text", llama_token_to_str(ctx, token) },
{ "token", token_with_probs.tok },
{ "token_text", llama_token_to_str(ctx, token_with_probs.tok) },
{ "has_next_token", has_next_token },
{ "n_remain", n_remain },
{ "num_tokens_predicted", num_tokens_predicted },
@@ -422,7 +460,7 @@ struct llama_server_context {
{ "stopping_word", stopping_word },
});
return token_text;
return result;
}
std::vector<float> getEmbedding() {
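
To make the accumulated state concrete, a hedged fragment (not part of the diff) that walks generated_text_probs after a completion finishes and prints each token with the candidates recorded for it; llama is the llama_server_context instance used in main():

    // sketch: dump the per-token probabilities gathered during a completion
    for (const auto & out : llama.generated_text_probs) {
        fprintf(stdout, "token: %s\n", out.tok_str.c_str());
        for (const auto & p : out.probs) {
            fprintf(stdout, "  %-12s %.3f\n", p.tok_str.c_str(), p.prob);
        }
    }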
@@ -664,6 +702,7 @@ static json format_generation_settings(llama_server_context & llama) {
{ "ignore_eos", ignore_eos },
{ "stream", llama.stream },
{ "logit_bias", llama.params.logit_bias },
{ "n_probs", llama.params.n_probs },
};
}
@@ -673,9 +712,26 @@ static json format_embedding_response(llama_server_context & llama) {
};
}
static json format_final_response(llama_server_context & llama, const std::string & content) {
static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_string_output> & probs) {
json completion_probabilities_json = json::array();
for (const auto & prob : probs) {
json probs_for_token = json::array();
for (const auto & p : prob.probs) {
probs_for_token.push_back(json {
{ "tok_str", p.tok_str },
{ "prob", p.prob },
});
}
completion_probabilities_json.push_back(json {
{"content", prob.tok_str},
{"probs", probs_for_token},
});
}
return json {
{ "content", content },
{ "completion_probabilities", completion_probabilities_json},
{ "stop", true },
{ "model", llama.params.model_alias },
{ "tokens_predicted", llama.num_tokens_predicted },
@@ -689,11 +745,25 @@ static json format_final_response(llama_server_context & llama, const std::string & content) {
};
}
static json format_partial_response(const std::string & content) {
return json {
static json format_partial_response(const std::string & content, const completion_string_output & probs) {
json res = json {
{ "content", content },
{ "stop", false },
};
// iterate through probs.probs, and add to res
json probs_json = json::array();
for (const auto & prob : probs.probs) {
probs_json.push_back(json {
{ "tok_str", prob.tok_str },
{ "prob", prob.prob },
});
}
if (probs.probs.size() > 0) {
res["probs"] = probs_json;
}
return res;
}
static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
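
Put together, the two formatters above yield payloads roughly like the following (token strings and probabilities are invented for illustration; unrelated fields such as model and tokens_predicted are omitted). Final, non-streaming response:

    {
      "content": " Paris",
      "completion_probabilities": [
        {
          "content": " Paris",
          "probs": [
            { "tok_str": " Paris", "prob": 0.87 },
            { "tok_str": " the",   "prob": 0.06 },
            { "tok_str": " a",     "prob": 0.03 }
          ]
        }
      ],
      "stop": true
    }

Streamed chunk, which carries at most one token's candidates under "probs":

    {
      "content": " Paris",
      "stop": false,
      "probs": [
        { "tok_str": " Paris", "prob": 0.87 },
        { "tok_str": " the",   "prob": 0.06 }
      ]
    }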
@@ -723,6 +793,7 @@ static void parse_options_completion(const json & body, llama_server_context & llama) {
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
llama.params.seed = body.value("seed", default_params.seed);
llama.params.prompt = body.value("prompt", default_params.prompt);
llama.params.n_probs = body.value("n_probs", default_params.n_probs);
llama.params.logit_bias.clear();
if (body.value("ignore_eos", false)) {
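
On the request side, a client opts in by adding "n_probs" to the completion request body. A hedged example (the /completion path and the other fields follow the existing server example and are shown only for illustration):

    POST /completion
    {
      "prompt": "Building a website can be done in 10 simple steps:",
      "n_predict": 16,
      "n_probs": 3
    }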
@@ -825,7 +896,8 @@ int main(int argc, char ** argv) {
size_t stop_pos = std::string::npos;
while (llama.has_next_token) {
const std::string token_text = llama.doCompletion();
const completion_string_output token_text_with_probs = llama.doCompletion();
const std::string token_text = token_text_with_probs.tok_str;
stop_pos = llama.findStoppingStrings(llama.generated_text,
token_text.size(), STOP_FULL);
@@ -839,7 +911,7 @@ int main(int argc, char ** argv) {
llama.generated_text.end());
}
const json data = format_final_response(llama, llama.generated_text);
const json data = format_final_response(llama, llama.generated_text, llama.generated_text_probs);
llama_print_timings(llama.ctx);
@@ -850,7 +922,7 @@ int main(int argc, char ** argv) {
size_t sent_count = 0;
while (llama.has_next_token) {
const std::string token_text = llama.doCompletion();
const completion_string_output token_text_with_probs = llama.doCompletion();
if (llama.multibyte_pending > 0) {
continue;
}
@@ -859,14 +931,14 @@ int main(int argc, char ** argv) {
const std::string str_test = llama.generated_text.substr(pos);
size_t stop_pos =
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(), STOP_FULL);
if (stop_pos != std::string::npos) {
llama.generated_text.erase(
llama.generated_text.begin() + pos + stop_pos,
llama.generated_text.end());
pos = std::min(sent_count, llama.generated_text.size());
} else {
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
stop_pos = llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(),
STOP_PARTIAL);
}
@@ -874,9 +946,9 @@ int main(int argc, char ** argv) {
sent_count += to_send.size();
const json data = llama.has_next_token
? format_partial_response(to_send)
? format_partial_response(to_send, token_text_with_probs)
// Generation is done, send extra information.
: format_final_response(llama, to_send);
: format_final_response(llama, to_send, {token_text_with_probs});
const std::string str =
"data: " +