update grpc impl
This commit is contained in:
parent
7a91429897
commit
1fdd8ac615
3 changed files with 399 additions and 97 deletions
|
@ -1,9 +1,9 @@
|
||||||
set(TARGET grpc-server)
|
set(TARGET grpc-server)
|
||||||
set(_PROTOBUF_LIBPROTOBUF libprotobuf)
|
set(_PROTOBUF_LIBPROTOBUF libprotobuf)
|
||||||
set(_REFLECTION grpc++_reflection)
|
set(_REFLECTION grpc++_reflection)
|
||||||
find_package(absl REQUIRED)
|
|
||||||
find_package(Protobuf CONFIG REQUIRED PATHS ${MY_INSTALL_DIR}/lib)
|
|
||||||
include_directories($ENV{MY_INSTALL_DIR}/include)
|
include_directories($ENV{MY_INSTALL_DIR}/include)
|
||||||
|
find_package(absl REQUIRED PATHS $ENV{MY_INSTALL_DIR}/lib)
|
||||||
|
find_package(Protobuf CONFIG REQUIRED PATHS $ENV{MY_INSTALL_DIR}/lib)
|
||||||
find_package(gRPC CONFIG REQUIRED)
|
find_package(gRPC CONFIG REQUIRED)
|
||||||
find_program(_PROTOBUF_PROTOC protoc)
|
find_program(_PROTOBUF_PROTOC protoc)
|
||||||
set(_GRPC_GRPCPP grpc++)
|
set(_GRPC_GRPCPP grpc++)
|
||||||
|
|
|
@ -44,9 +44,9 @@ using grpc::ServerContext;
|
||||||
using grpc::ServerUnaryReactor;
|
using grpc::ServerUnaryReactor;
|
||||||
using grpc::ServerWriteReactor;
|
using grpc::ServerWriteReactor;
|
||||||
using grpc::Status;
|
using grpc::Status;
|
||||||
using robot::Job;
|
using llama::Job;
|
||||||
using robot::LlamaGoService;
|
using llama::LlamaGoService;
|
||||||
using robot::Output;
|
using llama::Output;
|
||||||
|
|
||||||
struct server_params
|
struct server_params
|
||||||
{
|
{
|
||||||
|
@ -85,6 +85,20 @@ class LlamaServerContext
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
bool loaded;
|
bool loaded;
|
||||||
|
bool has_next_token{false};
|
||||||
|
int32_t num_tokens_predicted{0};
|
||||||
|
int32_t n_past{0};
|
||||||
|
int32_t n_consumed{0};
|
||||||
|
int32_t n_session_consumed{0};
|
||||||
|
int32_t n_remain{0};
|
||||||
|
|
||||||
|
std::vector<llama_token> embd;
|
||||||
|
std::vector<llama_token> last_n_tokens;
|
||||||
|
std::vector<llama_token> processed_tokens;
|
||||||
|
std::vector<llama_token> llama_token_newline;
|
||||||
|
std::vector<llama_token> embd_inp;
|
||||||
|
std::vector<std::vector<llama_token>> no_show_words;
|
||||||
|
std::vector<llama_token> tokens_predicted;
|
||||||
|
|
||||||
LlamaServerContext(gpt_params params_) : params(params_), threads(8)
|
LlamaServerContext(gpt_params params_) : params(params_), threads(8)
|
||||||
{
|
{
|
||||||
|
@ -93,9 +107,9 @@ public:
|
||||||
{
|
{
|
||||||
ctx_for_embedding = llama_init_from_gpt_params(params);
|
ctx_for_embedding = llama_init_from_gpt_params(params);
|
||||||
}
|
}
|
||||||
prams.embedding = false;
|
params.embedding = false;
|
||||||
ctx_for_completion = llama_init_from_gpt_params(params);
|
ctx = llama_init_from_gpt_params(params);
|
||||||
if (ctx_for_completion == NULL || (has_embedding && ctx_for_embedding == NULL))
|
if (ctx == NULL || (has_embedding && ctx_for_embedding == NULL))
|
||||||
{
|
{
|
||||||
loaded = false;
|
loaded = false;
|
||||||
fprintf(stderr, "%s: error: unable to load model\n", __func__);
|
fprintf(stderr, "%s: error: unable to load model\n", __func__);
|
||||||
|
@ -103,8 +117,6 @@ public:
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
loaded = true;
|
loaded = true;
|
||||||
// determine newline token
|
|
||||||
llama_token_newline = ::llama_tokenize(ctx, "\n", false);
|
|
||||||
last_n_tokens.resize(params.n_ctx);
|
last_n_tokens.resize(params.n_ctx);
|
||||||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
||||||
}
|
}
|
||||||
|
@ -129,79 +141,332 @@ public:
|
||||||
return embeddings_;
|
return embeddings_;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool complete(std::string content, int *n_remain, llama_token &result)
|
void rewind()
|
||||||
{
|
{
|
||||||
|
// as_loop = false;
|
||||||
|
params.antiprompt.clear();
|
||||||
|
no_show_words.clear();
|
||||||
|
num_tokens_predicted = 0;
|
||||||
|
// generated_text = "";
|
||||||
|
}
|
||||||
|
|
||||||
const float temp = params.temp;
|
std::string doCompletion()
|
||||||
const int mirostat = params.mirostat;
|
{
|
||||||
const bool penalize_nl = params.penalize_nl;
|
llama_token token = nextToken();
|
||||||
|
if (token == -1)
|
||||||
auto logits = llama_get_logits(ctx_for_completion);
|
|
||||||
auto n_vocab = llama_n_vocab(ctx_for_completion);
|
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
|
||||||
candidates.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++)
|
|
||||||
{
|
{
|
||||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
return "";
|
||||||
}
|
}
|
||||||
|
tokens_predicted.clear();
|
||||||
|
tokens_predicted.push_back(token);
|
||||||
|
|
||||||
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
|
// Avoid add the no show words to the response
|
||||||
|
for (std::vector<llama_token> word_tokens : no_show_words)
|
||||||
// Apply penalties
|
|
||||||
float nl_logit = logits[llama_token_nl()];
|
|
||||||
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
|
|
||||||
llama_sample_repetition_penalty(ctx_for_completion, &candidates_p,
|
|
||||||
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
|
||||||
last_n_repeat, repeat_penalty);
|
|
||||||
llama_sample_frequency_and_presence_penalties(ctx_for_completion, &candidates_p,
|
|
||||||
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
|
||||||
last_n_repeat, alpha_frequency, alpha_presence);
|
|
||||||
if (!penalize_nl)
|
|
||||||
{
|
{
|
||||||
logits[llama_token_nl()] = nl_logit;
|
int match_token = 1;
|
||||||
}
|
if (tokens_predicted.front() == word_tokens.front())
|
||||||
|
|
||||||
if (temp <= 0)
|
|
||||||
{
|
|
||||||
// Greedy sampling
|
|
||||||
id = llama_sample_token_greedy(ctx_for_completion, &candidates_p);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
if (mirostat == 1)
|
|
||||||
{
|
{
|
||||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
bool execute_matching = true;
|
||||||
const int mirostat_m = 100;
|
if (tokens_predicted.size() > 1)
|
||||||
llama_sample_temperature(ctx_for_completion, &candidates_p, temp);
|
{ // if previus tokens had been tested
|
||||||
id = llama_sample_token_mirostat(ctx_for_completion, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
for (int i = 1; i < word_tokens.size(); i++)
|
||||||
|
{
|
||||||
|
if (i >= tokens_predicted.size())
|
||||||
|
{
|
||||||
|
match_token = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (tokens_predicted[i] == word_tokens[i])
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
execute_matching = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
while (execute_matching)
|
||||||
|
{
|
||||||
|
if (match_token == word_tokens.size())
|
||||||
|
{
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
token = nextToken();
|
||||||
|
tokens_predicted.push_back(token);
|
||||||
|
if (token == word_tokens[match_token])
|
||||||
|
{ // the token follow the sequence
|
||||||
|
match_token++;
|
||||||
|
}
|
||||||
|
else if (match_token < word_tokens.size())
|
||||||
|
{ // no complete all word sequence
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else if (mirostat == 2)
|
}
|
||||||
|
|
||||||
|
std::string generated_text = "";
|
||||||
|
for (llama_token tkn : tokens_predicted)
|
||||||
|
{
|
||||||
|
generated_text += llama_token_to_str(ctx, tkn);
|
||||||
|
}
|
||||||
|
return std::string(generated_text);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool loadPrompt(std::string prompt)
|
||||||
|
{
|
||||||
|
// prompt.insert(0, " Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"); // always add a first space
|
||||||
|
prompt.insert(0, 1, ' '); // always add a first space
|
||||||
|
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, prompt, true);
|
||||||
|
// compare the evaluated prompt with the new prompt
|
||||||
|
int new_prompt_len = 0;
|
||||||
|
for (int i = 0; i < prompt_tokens.size(); i++)
|
||||||
|
{
|
||||||
|
if (i < processed_tokens.size() &&
|
||||||
|
processed_tokens[i] == prompt_tokens[i])
|
||||||
{
|
{
|
||||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
continue;
|
||||||
llama_sample_temperature(ctx_for_completion, &candidates_p, temp);
|
|
||||||
id = llama_sample_token_mirostat_v2(ctx_for_completion, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
// Temperature sampling
|
embd_inp.push_back(prompt_tokens[i]);
|
||||||
llama_sample_tail_free(ctx_for_completion, &candidates_p, tfs_z, 1);
|
if (new_prompt_len == 0)
|
||||||
llama_sample_typical(ctx_for_completion, &candidates_p, typical_p, 1);
|
{
|
||||||
llama_sample_top_p(ctx_for_completion, &candidates_p, top_p, 1);
|
if (i - 1 < n_past)
|
||||||
llama_sample_temperature(ctx_for_completion, &candidates_p, temp);
|
{
|
||||||
id = llama_sample_token(ctx_for_completion, &candidates_p);
|
processed_tokens.erase(processed_tokens.begin() + i, processed_tokens.end());
|
||||||
|
}
|
||||||
|
// Evaluate the new fragment prompt from the last token processed.
|
||||||
|
n_past = processed_tokens.size();
|
||||||
|
}
|
||||||
|
new_prompt_len++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (n_past > 0 && params.interactive)
|
||||||
|
{
|
||||||
|
n_remain -= new_prompt_len;
|
||||||
|
}
|
||||||
|
if ((int)embd_inp.size() > params.n_ctx - 4)
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
has_next_token = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void beginCompletion()
|
||||||
|
{
|
||||||
|
if (n_remain == 0)
|
||||||
|
{
|
||||||
|
// number of tokens to keep when resetting context
|
||||||
|
if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size())
|
||||||
|
{
|
||||||
|
params.n_keep = (int)embd_inp.size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n_remain = params.n_predict;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token nextToken()
|
||||||
|
{
|
||||||
|
llama_token result = -1;
|
||||||
|
if (embd.size() > 0)
|
||||||
|
{
|
||||||
|
if (n_past + (int)embd.size() > params.n_ctx)
|
||||||
|
{
|
||||||
|
// Reset context
|
||||||
|
const int n_left = n_past - params.n_keep;
|
||||||
|
n_past = std::max(1, params.n_keep);
|
||||||
|
processed_tokens.erase(processed_tokens.begin() + n_past, processed_tokens.end());
|
||||||
|
embd.insert(embd.begin(), last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
|
||||||
|
}
|
||||||
|
for (int i = 0; i < (int)embd.size(); i += params.n_batch)
|
||||||
|
{
|
||||||
|
int n_eval = (int)embd.size() - i;
|
||||||
|
if (n_eval > params.n_batch)
|
||||||
|
{
|
||||||
|
n_eval = params.n_batch;
|
||||||
|
}
|
||||||
|
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads))
|
||||||
|
{
|
||||||
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
|
has_next_token = false;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
n_past += n_eval;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
embd.clear();
|
||||||
|
if ((int)embd_inp.size() <= n_consumed && has_next_token)
|
||||||
|
{
|
||||||
|
// out of user input, sample next token
|
||||||
|
const float temp = params.temp;
|
||||||
|
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
|
||||||
|
const float top_p = params.top_p;
|
||||||
|
const float tfs_z = params.tfs_z;
|
||||||
|
const float typical_p = params.typical_p;
|
||||||
|
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
|
||||||
|
const float repeat_penalty = params.repeat_penalty;
|
||||||
|
const float alpha_presence = params.presence_penalty;
|
||||||
|
const float alpha_frequency = params.frequency_penalty;
|
||||||
|
const int mirostat = params.mirostat;
|
||||||
|
const float mirostat_tau = params.mirostat_tau;
|
||||||
|
const float mirostat_eta = params.mirostat_eta;
|
||||||
|
const bool penalize_nl = params.penalize_nl;
|
||||||
|
llama_token id = 0;
|
||||||
|
{
|
||||||
|
auto logits = llama_get_logits(ctx);
|
||||||
|
auto n_vocab = llama_n_vocab(ctx);
|
||||||
|
|
||||||
|
// Apply params.logit_bias map
|
||||||
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
|
||||||
|
{
|
||||||
|
logits[it->first] += it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<llama_token_data> candidates;
|
||||||
|
candidates.reserve(n_vocab);
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++)
|
||||||
|
{
|
||||||
|
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
|
||||||
|
|
||||||
|
// Apply penalties
|
||||||
|
float nl_logit = logits[llama_token_nl()];
|
||||||
|
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
|
||||||
|
llama_sample_repetition_penalty(ctx, &candidates_p,
|
||||||
|
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||||
|
last_n_repeat, repeat_penalty);
|
||||||
|
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
|
||||||
|
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||||
|
last_n_repeat, alpha_frequency, alpha_presence);
|
||||||
|
if (!penalize_nl)
|
||||||
|
{
|
||||||
|
logits[llama_token_nl()] = nl_logit;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (temp <= 0)
|
||||||
|
{
|
||||||
|
// Greedy sampling
|
||||||
|
id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (mirostat == 1)
|
||||||
|
{
|
||||||
|
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||||
|
const int mirostat_m = 100;
|
||||||
|
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||||
|
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
||||||
|
}
|
||||||
|
else if (mirostat == 2)
|
||||||
|
{
|
||||||
|
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||||
|
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||||
|
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Temperature sampling
|
||||||
|
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
|
||||||
|
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
|
||||||
|
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
||||||
|
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||||
|
id = llama_sample_token(ctx, &candidates_p);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
last_n_tokens.erase(last_n_tokens.begin());
|
||||||
|
last_n_tokens.push_back(id);
|
||||||
|
processed_tokens.push_back(id);
|
||||||
|
num_tokens_predicted++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// replace end of text token with newline token when in interactive mode
|
||||||
|
if (id == llama_token_eos() && params.interactive)
|
||||||
|
{
|
||||||
|
id = llama_token_newline.front();
|
||||||
|
if (params.antiprompt.size() != 0)
|
||||||
|
{
|
||||||
|
// tokenize and inject first reverse prompt
|
||||||
|
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
|
||||||
|
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add it to the context
|
||||||
|
embd.push_back(id);
|
||||||
|
for (auto id : embd)
|
||||||
|
{
|
||||||
|
result = id;
|
||||||
|
}
|
||||||
|
// decrement remaining sampling budget
|
||||||
|
--n_remain;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// some user input remains from prompt or interaction, forward it to processing
|
||||||
|
while ((int)embd_inp.size() > n_consumed)
|
||||||
|
{
|
||||||
|
embd.push_back(embd_inp[n_consumed]);
|
||||||
|
last_n_tokens.erase(last_n_tokens.begin());
|
||||||
|
last_n_tokens.push_back(embd_inp[n_consumed]);
|
||||||
|
processed_tokens.push_back(embd_inp[n_consumed]);
|
||||||
|
++n_consumed;
|
||||||
|
if ((int)embd.size() >= params.n_batch)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (params.interactive && (int)embd_inp.size() <= n_consumed)
|
||||||
|
{
|
||||||
|
// check for reverse prompt
|
||||||
|
if (params.antiprompt.size())
|
||||||
|
{
|
||||||
|
std::string last_output;
|
||||||
|
for (auto id : last_n_tokens)
|
||||||
|
{
|
||||||
|
last_output += llama_token_to_str(ctx, id);
|
||||||
|
}
|
||||||
|
has_next_token = true;
|
||||||
|
// Check if each of the reverse prompts appears at the end of the output.
|
||||||
|
for (std::string &antiprompt : params.antiprompt)
|
||||||
|
{
|
||||||
|
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos)
|
||||||
|
{
|
||||||
|
has_next_token = false;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (n_past > 0)
|
||||||
|
{
|
||||||
|
has_next_token = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
--n_remain;
|
if (!embd.empty() && embd.back() == llama_token_eos())
|
||||||
return id == llama_token_eos() || n_remain <= 0;
|
{
|
||||||
|
has_next_token = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.interactive && n_remain <= 0 && params.n_predict != -1)
|
||||||
|
{
|
||||||
|
n_remain = params.n_predict;
|
||||||
|
}
|
||||||
|
has_next_token = n_remain != 0;
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string tokenToString(llama_token token)
|
std::string tokenToString(llama_token token)
|
||||||
{
|
{
|
||||||
if (token == llama_token_eos())
|
if (token == llama_token_eos())
|
||||||
{
|
{
|
||||||
return ""
|
return "";
|
||||||
}
|
}
|
||||||
else if (token == llama_token_nl())
|
else if (token == llama_token_nl())
|
||||||
{
|
{
|
||||||
|
@ -209,18 +474,19 @@ public:
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
return std::string(llama_token_to_str(ctx_for_completion, token));
|
return std::string(llama_token_to_str(ctx, token));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
llama_context *ctx_for_completion;
|
llama_context *ctx;
|
||||||
llama_context *ctx_for_embedding;
|
llama_context *ctx_for_embedding;
|
||||||
int threads;
|
int threads;
|
||||||
|
int n_ctx;
|
||||||
|
|
||||||
std::vector<llama_token> last_n_tokens;
|
// std::vector<llama_token> last_n_tokens;
|
||||||
std::vector<llama_token> llama_token_newline;
|
// std::vector<llama_token> llama_token_newline;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Logic and data behind the server's behavior.
|
// Logic and data behind the server's behavior.
|
||||||
|
@ -230,41 +496,72 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
|
||||||
class Reactor : public grpc::ServerWriteReactor<Output>
|
class Reactor : public grpc::ServerWriteReactor<Output>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Reactor(CallbackServerContext *ctx, const Job *request)
|
Reactor(CallbackServerContext *ctx, LlamaServerContext *llama, const Job *request)
|
||||||
: ctx_(ctx), request_(request)
|
: ctx_(ctx), request_(request), llama_(llama)
|
||||||
{
|
{
|
||||||
content.insert(0, 1, ' ');
|
if (llama->loadPrompt(request->prompt()))
|
||||||
std::vector<llama_token> tokens = ::llama_tokenize(ctx_for_completion, content, true);
|
|
||||||
if (tokens.size() > 0)
|
|
||||||
{
|
{
|
||||||
if (llama_eval(ctx_for_completion, tokens.data(), tokens.size(), 0, 6))
|
llama->beginCompletion();
|
||||||
{
|
NextWrite();
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
}
|
||||||
return "";
|
else
|
||||||
}
|
{
|
||||||
|
Finish(grpc::Status::OK);
|
||||||
}
|
}
|
||||||
// input done, begin to generate
|
|
||||||
// generate loop
|
|
||||||
n_remain = params.n_predict;
|
|
||||||
bool finished = false;
|
|
||||||
do
|
|
||||||
{
|
|
||||||
llama_token* words;
|
|
||||||
auto finished = llama->complete(request->prompt(),&n_remain, words);
|
|
||||||
Output response;
|
|
||||||
response.set_output(llama->tokenToString(words));
|
|
||||||
StartWrite(&response);
|
|
||||||
} while (!finished)
|
|
||||||
|
|
||||||
Output response;
|
|
||||||
StartWriteLast(&response, WriteOptions());
|
|
||||||
ctx_->TryCancel();
|
|
||||||
}
|
}
|
||||||
void OnDone() override { delete this; }
|
void OnDone() override
|
||||||
|
{
|
||||||
|
fprintf(stderr, "completion done");
|
||||||
|
delete this;
|
||||||
|
}
|
||||||
|
void OnWriteDone(bool /*ok*/) override
|
||||||
|
{
|
||||||
|
// fprintf(stderr, "on write done");
|
||||||
|
NextWrite();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CallbackServerContext *const ctx_;
|
CallbackServerContext *const ctx_;
|
||||||
|
LlamaServerContext *llama_;
|
||||||
const Job *const request_;
|
const Job *const request_;
|
||||||
|
int n_remain{0};
|
||||||
|
std::mutex finish_mu_;
|
||||||
|
bool finished_{false};
|
||||||
|
Output response;
|
||||||
|
|
||||||
|
void NextWrite()
|
||||||
|
{
|
||||||
|
// loop inference until finish completion
|
||||||
|
if (llama_->has_next_token)
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> l(finish_mu_);
|
||||||
|
auto result = llama_->doCompletion();
|
||||||
|
fprintf(stderr, "%s", result.c_str());
|
||||||
|
response.set_output(result);
|
||||||
|
StartWrite(&response);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
{
|
||||||
|
response.set_status(llama::Status::FINISHED);
|
||||||
|
std::lock_guard<std::mutex>
|
||||||
|
l(finish_mu_);
|
||||||
|
StartWriteLast(&response, grpc::WriteOptions());
|
||||||
|
}
|
||||||
|
// If we use WriteLast, we shouldn't wait before attempting Finish
|
||||||
|
FinishOnce(Status::OK);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FinishOnce(const Status &s)
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> l(finish_mu_);
|
||||||
|
if (!finished_)
|
||||||
|
{
|
||||||
|
Finish(s);
|
||||||
|
finished_ = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -277,8 +574,12 @@ public:
|
||||||
CallbackServerContext *context, const Job *request)
|
CallbackServerContext *context, const Job *request)
|
||||||
{
|
{
|
||||||
fprintf(stderr, "%s : get answer\n", __func__);
|
fprintf(stderr, "%s : get answer\n", __func__);
|
||||||
std::vector<float> embeded = llama->complete(request->prompt());
|
llama->rewind();
|
||||||
return new Reactor(context, request);
|
// std::vector<float> embeded = llama->complete(request->prompt());
|
||||||
|
Reactor *reactor = new Reactor(context, llama, request);
|
||||||
|
// reactors.push_back(reactor);
|
||||||
|
|
||||||
|
return reactor;
|
||||||
}
|
}
|
||||||
|
|
||||||
ServerUnaryReactor *Embed(
|
ServerUnaryReactor *Embed(
|
||||||
|
@ -295,6 +596,7 @@ public:
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LlamaServerContext *llama;
|
LlamaServerContext *llama;
|
||||||
|
// std::vector<Reactor> reactors;
|
||||||
int threads;
|
int threads;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -433,7 +735,7 @@ int main(int argc, char **argv)
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
params.embedding = true;
|
// params.embedding = true;
|
||||||
|
|
||||||
if (params.seed <= 0)
|
if (params.seed <= 0)
|
||||||
{
|
{
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
|
|
||||||
package robot;
|
package llama;
|
||||||
|
|
||||||
option go_package = "./pkg/grpc";
|
option go_package = "./pkg/grpc";
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue