diff --git a/examples/grpc-server/CMakeLists.txt b/examples/grpc-server/CMakeLists.txt
index 51540ca4b..1ecfb6f7c 100644
--- a/examples/grpc-server/CMakeLists.txt
+++ b/examples/grpc-server/CMakeLists.txt
@@ -1,9 +1,9 @@
 set(TARGET grpc-server)
 set(_PROTOBUF_LIBPROTOBUF libprotobuf)
 set(_REFLECTION grpc++_reflection)
-find_package(absl REQUIRED)
-find_package(Protobuf CONFIG REQUIRED PATHS ${MY_INSTALL_DIR}/lib)
 include_directories($ENV{MY_INSTALL_DIR}/include)
+find_package(absl REQUIRED PATHS $ENV{MY_INSTALL_DIR}/lib)
+find_package(Protobuf CONFIG REQUIRED PATHS $ENV{MY_INSTALL_DIR}/lib)
 find_package(gRPC CONFIG REQUIRED)
 find_program(_PROTOBUF_PROTOC protoc)
 set(_GRPC_GRPCPP grpc++)
diff --git a/examples/grpc-server/grpc-server.cpp b/examples/grpc-server/grpc-server.cpp
index a2c7be005..901edfd76 100644
--- a/examples/grpc-server/grpc-server.cpp
+++ b/examples/grpc-server/grpc-server.cpp
@@ -44,9 +44,9 @@ using grpc::ServerContext;
 using grpc::ServerUnaryReactor;
 using grpc::ServerWriteReactor;
 using grpc::Status;
-using robot::Job;
-using robot::LlamaGoService;
-using robot::Output;
+using llama::Job;
+using llama::LlamaGoService;
+using llama::Output;
 
 struct server_params
 {
@@ -85,6 +85,20 @@ class LlamaServerContext
 {
 public:
     bool loaded;
+    bool has_next_token{false};
+    int32_t num_tokens_predicted{0};
+    int32_t n_past{0};
+    int32_t n_consumed{0};
+    int32_t n_session_consumed{0};
+    int32_t n_remain{0};
+
+    std::vector<llama_token> embd;
+    std::vector<llama_token> last_n_tokens;
+    std::vector<llama_token> processed_tokens;
+    std::vector<llama_token> llama_token_newline;
+    std::vector<llama_token> embd_inp;
+    std::vector<std::vector<llama_token>> no_show_words;
+    std::vector<llama_token> tokens_predicted;
 
     LlamaServerContext(gpt_params params_) : params(params_), threads(8)
     {
@@ -93,9 +107,9 @@ public:
         {
             ctx_for_embedding = llama_init_from_gpt_params(params);
         }
-        prams.embedding = false;
-        ctx_for_completion = llama_init_from_gpt_params(params);
-        if (ctx_for_completion == NULL || (has_embedding && ctx_for_embedding == NULL))
+        params.embedding = false;
+        ctx = llama_init_from_gpt_params(params);
+        if (ctx == NULL || (has_embedding && ctx_for_embedding == NULL))
         {
             loaded = false;
             fprintf(stderr, "%s: error: unable to load model\n", __func__);
@@ -103,8 +117,6 @@ public:
         else
         {
             loaded = true;
-            // determine newline token
-            llama_token_newline = ::llama_tokenize(ctx, "\n", false);
             last_n_tokens.resize(params.n_ctx);
             std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
         }
@@ -129,79 +141,332 @@ public:
         return embeddings_;
     }
 
-    bool complete(std::string content, int *n_remain, llama_token &result)
+    void rewind()
     {
+        // as_loop = false;
+        params.antiprompt.clear();
+        no_show_words.clear();
+        num_tokens_predicted = 0;
+        // generated_text = "";
+    }
 
-        const float temp = params.temp;
-        const int mirostat = params.mirostat;
-        const bool penalize_nl = params.penalize_nl;
-
-        auto logits = llama_get_logits(ctx_for_completion);
-        auto n_vocab = llama_n_vocab(ctx_for_completion);
-
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++)
+    std::string doCompletion()
+    {
+        llama_token token = nextToken();
+        if (token == -1)
         {
-            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+            return "";
         }
+        tokens_predicted.clear();
+        tokens_predicted.push_back(token);
 
-        llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
-
-        // Apply penalties
-        float nl_logit = logits[llama_token_nl()];
-        auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
-        llama_sample_repetition_penalty(ctx_for_completion, &candidates_p,
-                                        last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                                        last_n_repeat, repeat_penalty);
-        llama_sample_frequency_and_presence_penalties(ctx_for_completion, &candidates_p,
-                                                      last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                                                      last_n_repeat, alpha_frequency, alpha_presence);
-        if (!penalize_nl)
+        // Avoid adding the no-show words to the response
+        for (std::vector<llama_token> word_tokens : no_show_words)
         {
-            logits[llama_token_nl()] = nl_logit;
-        }
-
-        if (temp <= 0)
-        {
-            // Greedy sampling
-            id = llama_sample_token_greedy(ctx_for_completion, &candidates_p);
-        }
-        else
-        {
-            if (mirostat == 1)
+            int match_token = 1;
+            if (tokens_predicted.front() == word_tokens.front())
             {
-                static float mirostat_mu = 2.0f * mirostat_tau;
-                const int mirostat_m = 100;
-                llama_sample_temperature(ctx_for_completion, &candidates_p, temp);
-                id = llama_sample_token_mirostat(ctx_for_completion, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+                bool execute_matching = true;
+                if (tokens_predicted.size() > 1)
+                { // if previous tokens had been tested
+                    for (int i = 1; i < word_tokens.size(); i++)
+                    {
+                        if (i >= tokens_predicted.size())
+                        {
+                            match_token = i;
+                            break;
+                        }
+                        if (tokens_predicted[i] == word_tokens[i])
+                        {
+                            continue;
+                        }
+                        else
+                        {
+                            execute_matching = false;
+                            break;
+                        }
+                    }
+                }
+                while (execute_matching)
+                {
+                    if (match_token == word_tokens.size())
+                    {
+                        return "";
+                    }
+                    token = nextToken();
+                    tokens_predicted.push_back(token);
+                    if (token == word_tokens[match_token])
+                    { // the token follows the sequence
+                        match_token++;
+                    }
+                    else if (match_token < word_tokens.size())
+                    { // the word sequence was not fully matched
+                        break;
+                    }
+                }
             }
-            else if (mirostat == 2)
+        }
+
+        std::string generated_text = "";
+        for (llama_token tkn : tokens_predicted)
+        {
+            generated_text += llama_token_to_str(ctx, tkn);
+        }
+        return std::string(generated_text);
+    }
+
+    bool loadPrompt(std::string prompt)
+    {
+        // prompt.insert(0, " Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"); // always add a first space
+        prompt.insert(0, 1, ' '); // always add a first space
+        std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, prompt, true);
+        // compare the evaluated prompt with the new prompt
+        int new_prompt_len = 0;
+        for (int i = 0; i < prompt_tokens.size(); i++)
+        {
+            if (i < processed_tokens.size() &&
+                processed_tokens[i] == prompt_tokens[i])
             {
-                static float mirostat_mu = 2.0f * mirostat_tau;
-                llama_sample_temperature(ctx_for_completion, &candidates_p, temp);
-                id = llama_sample_token_mirostat_v2(ctx_for_completion, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+                continue;
             }
             else
             {
-                // Temperature sampling
-                llama_sample_tail_free(ctx_for_completion, &candidates_p, tfs_z, 1);
-                llama_sample_typical(ctx_for_completion, &candidates_p, typical_p, 1);
-                llama_sample_top_p(ctx_for_completion, &candidates_p, top_p, 1);
-                llama_sample_temperature(ctx_for_completion, &candidates_p, temp);
-                id = llama_sample_token(ctx_for_completion, &candidates_p);
+                embd_inp.push_back(prompt_tokens[i]);
+                if (new_prompt_len == 0)
+                {
+                    if (i - 1 < n_past)
+                    {
+                        processed_tokens.erase(processed_tokens.begin() + i, processed_tokens.end());
+                    }
+                    // Evaluate the new fragment prompt from the last token processed.
+                    n_past = processed_tokens.size();
+                }
+                new_prompt_len++;
+            }
+        }
+        if (n_past > 0 && params.interactive)
+        {
+            n_remain -= new_prompt_len;
+        }
+        if ((int)embd_inp.size() > params.n_ctx - 4)
+        {
+            return false;
+        }
+        has_next_token = true;
+        return true;
+    }
+
+    void beginCompletion()
+    {
+        if (n_remain == 0)
+        {
+            // number of tokens to keep when resetting context
+            if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size())
+            {
+                params.n_keep = (int)embd_inp.size();
+            }
+        }
+        n_remain = params.n_predict;
+    }
+
+    llama_token nextToken()
+    {
+        llama_token result = -1;
+        if (embd.size() > 0)
+        {
+            if (n_past + (int)embd.size() > params.n_ctx)
+            {
+                // Reset context
+                const int n_left = n_past - params.n_keep;
+                n_past = std::max(1, params.n_keep);
+                processed_tokens.erase(processed_tokens.begin() + n_past, processed_tokens.end());
+                embd.insert(embd.begin(), last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
+            }
+            for (int i = 0; i < (int)embd.size(); i += params.n_batch)
+            {
+                int n_eval = (int)embd.size() - i;
+                if (n_eval > params.n_batch)
+                {
+                    n_eval = params.n_batch;
+                }
+                if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads))
+                {
+                    fprintf(stderr, "%s : failed to eval\n", __func__);
+                    has_next_token = false;
+                    return result;
+                }
+                n_past += n_eval;
+            }
+        }
+        embd.clear();
+        if ((int)embd_inp.size() <= n_consumed && has_next_token)
+        {
+            // out of user input, sample next token
+            const float temp = params.temp;
+            const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+            const float top_p = params.top_p;
+            const float tfs_z = params.tfs_z;
+            const float typical_p = params.typical_p;
+            const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
+            const float repeat_penalty = params.repeat_penalty;
+            const float alpha_presence = params.presence_penalty;
+            const float alpha_frequency = params.frequency_penalty;
+            const int mirostat = params.mirostat;
+            const float mirostat_tau = params.mirostat_tau;
+            const float mirostat_eta = params.mirostat_eta;
+            const bool penalize_nl = params.penalize_nl;
+            llama_token id = 0;
+            {
+                auto logits = llama_get_logits(ctx);
+                auto n_vocab = llama_n_vocab(ctx);
+
+                // Apply params.logit_bias map
+                for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
+                {
+                    logits[it->first] += it->second;
+                }
+
+                std::vector<llama_token_data> candidates;
+                candidates.reserve(n_vocab);
+                for (llama_token token_id = 0; token_id < n_vocab; token_id++)
+                {
+                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+                }
+
+                llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
+
+                // Apply penalties
+                float nl_logit = logits[llama_token_nl()];
+                auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
+                llama_sample_repetition_penalty(ctx, &candidates_p,
+                                                last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                                                last_n_repeat, repeat_penalty);
+                llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+                                                              last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                                                              last_n_repeat, alpha_frequency, alpha_presence);
+                if (!penalize_nl)
+                {
+                    logits[llama_token_nl()] = nl_logit;
+                }
+
+                if (temp <= 0)
+                {
+                    // Greedy sampling
+                    id = llama_sample_token_greedy(ctx, &candidates_p);
+                }
+                else
+                {
+                    if (mirostat == 1)
+                    {
+                        static float mirostat_mu = 2.0f * mirostat_tau;
+                        const int mirostat_m = 100;
+                        llama_sample_temperature(ctx, &candidates_p, temp);
+                        id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+                    }
+                    else if (mirostat == 2)
+                    {
+                        static float mirostat_mu = 2.0f * mirostat_tau;
+                        llama_sample_temperature(ctx, &candidates_p, temp);
+                        id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+                    }
+                    else
+                    {
+                        // Temperature sampling
+                        llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
+                        llama_sample_typical(ctx, &candidates_p, typical_p, 1);
+                        llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+                        llama_sample_temperature(ctx, &candidates_p, temp);
+                        id = llama_sample_token(ctx, &candidates_p);
+                    }
+                }
+                last_n_tokens.erase(last_n_tokens.begin());
+                last_n_tokens.push_back(id);
+                processed_tokens.push_back(id);
+                num_tokens_predicted++;
+            }
+
+            // replace end of text token with newline token when in interactive mode
+            if (id == llama_token_eos() && params.interactive)
+            {
+                id = llama_token_newline.front();
+                if (params.antiprompt.size() != 0)
+                {
+                    // tokenize and inject first reverse prompt
+                    const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+                    embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+                }
+            }
+
+            // add it to the context
+            embd.push_back(id);
+            for (auto id : embd)
+            {
+                result = id;
+            }
+            // decrement remaining sampling budget
+            --n_remain;
+        }
+        else
+        {
+            // some user input remains from prompt or interaction, forward it to processing
+            while ((int)embd_inp.size() > n_consumed)
+            {
+                embd.push_back(embd_inp[n_consumed]);
+                last_n_tokens.erase(last_n_tokens.begin());
+                last_n_tokens.push_back(embd_inp[n_consumed]);
+                processed_tokens.push_back(embd_inp[n_consumed]);
+                ++n_consumed;
+                if ((int)embd.size() >= params.n_batch)
+                {
+                    break;
+                }
+            }
+        }
+        if (params.interactive && (int)embd_inp.size() <= n_consumed)
+        {
+            // check for reverse prompt
+            if (params.antiprompt.size())
+            {
+                std::string last_output;
+                for (auto id : last_n_tokens)
+                {
+                    last_output += llama_token_to_str(ctx, id);
+                }
+                has_next_token = true;
+                // Check if each of the reverse prompts appears at the end of the output.
+                for (std::string &antiprompt : params.antiprompt)
+                {
+                    if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos)
+                    {
+                        has_next_token = false;
+                        return result;
+                    }
+                }
+            }
+            if (n_past > 0)
+            {
+                has_next_token = true;
             }
         }
-        --n_remain;
-        return id == llama_token_eos() || n_remain <= 0;
+        if (!embd.empty() && embd.back() == llama_token_eos())
+        {
+            has_next_token = false;
+        }
+
+        if (params.interactive && n_remain <= 0 && params.n_predict != -1)
+        {
+            n_remain = params.n_predict;
+        }
+        has_next_token = n_remain != 0;
+        return result;
     }
 
     std::string tokenToString(llama_token token)
     {
         if (token == llama_token_eos())
         {
-            return ""
+            return "";
         }
         else if (token == llama_token_nl())
         {
@@ -209,18 +474,19 @@ public:
         }
         else
         {
-            return std::string(llama_token_to_str(ctx_for_completion, token));
+            return std::string(llama_token_to_str(ctx, token));
         }
     }
 
 private:
     gpt_params params;
-    llama_context *ctx_for_completion;
+    llama_context *ctx;
     llama_context *ctx_for_embedding;
     int threads;
+    int n_ctx;
 
-    std::vector<llama_token> last_n_tokens;
-    std::vector<llama_token> llama_token_newline;
+    // std::vector<llama_token> last_n_tokens;
+    // std::vector<llama_token> llama_token_newline;
 };
 
 // Logic and data behind the server's behavior.
@@ -230,41 +496,72 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
     class Reactor : public grpc::ServerWriteReactor<Output>
     {
     public:
-        Reactor(CallbackServerContext *ctx, const Job *request)
-            : ctx_(ctx), request_(request)
+        Reactor(CallbackServerContext *ctx, LlamaServerContext *llama, const Job *request)
+            : ctx_(ctx), request_(request), llama_(llama)
         {
-            content.insert(0, 1, ' ');
-            std::vector<llama_token> tokens = ::llama_tokenize(ctx_for_completion, content, true);
-            if (tokens.size() > 0)
+            if (llama->loadPrompt(request->prompt()))
             {
-                if (llama_eval(ctx_for_completion, tokens.data(), tokens.size(), 0, 6))
-                {
-                    fprintf(stderr, "%s : failed to eval\n", __func__);
-                    return "";
-                }
+                llama->beginCompletion();
+                NextWrite();
+            }
+            else
+            {
+                Finish(grpc::Status::OK);
             }
-            // input done, begin to generate
-            // generate loop
-            n_remain = params.n_predict;
-            bool finished = false;
-            do
-            {
-                llama_token* words;
-                auto finished = llama->complete(request->prompt(),&n_remain, words);
-                Output response;
-                response.set_output(llama->tokenToString(words));
-                StartWrite(&response);
-            } while (!finished)
-
-            Output response;
-            StartWriteLast(&response, WriteOptions());
-            ctx_->TryCancel();
         }
-        void OnDone() override { delete this; }
+        void OnDone() override
+        {
+            fprintf(stderr, "completion done");
+            delete this;
+        }
+        void OnWriteDone(bool /*ok*/) override
+        {
+            // fprintf(stderr, "on write done");
+            NextWrite();
+        }
 
     private:
         CallbackServerContext *const ctx_;
+        LlamaServerContext *llama_;
         const Job *const request_;
+        int n_remain{0};
+        std::mutex finish_mu_;
+        bool finished_{false};
+        Output response;
+
+        void NextWrite()
+        {
+            // loop inference until the completion finishes
+            if (llama_->has_next_token)
+            {
+                std::lock_guard<std::mutex> l(finish_mu_);
+                auto result = llama_->doCompletion();
+                fprintf(stderr, "%s", result.c_str());
+                response.set_output(result);
+                StartWrite(&response);
+            }
+            else
+            {
+                {
+                    response.set_status(llama::Status::FINISHED);
+                    std::lock_guard<std::mutex>
+                        l(finish_mu_);
+                    StartWriteLast(&response, grpc::WriteOptions());
+                }
+                // If we use WriteLast, we shouldn't wait before attempting Finish
+                FinishOnce(Status::OK);
+            }
+        }
+
+        void FinishOnce(const Status &s)
+        {
+            std::lock_guard<std::mutex> l(finish_mu_);
+            if (!finished_)
+            {
+                Finish(s);
+                finished_ = true;
+            }
+        }
     };
 
 public:
@@ -277,8 +574,12 @@ public:
         CallbackServerContext *context, const Job *request)
     {
         fprintf(stderr, "%s : get answer\n", __func__);
-        std::vector<llama_token> embeded = llama->complete(request->prompt());
-        return new Reactor(context, request);
+        llama->rewind();
+        // std::vector<llama_token> embeded = llama->complete(request->prompt());
+        Reactor *reactor = new Reactor(context, llama, request);
+        // reactors.push_back(reactor);
+
+        return reactor;
     }
 
     ServerUnaryReactor *Embed(
@@ -295,6 +596,7 @@ public:
 
 private:
     LlamaServerContext *llama;
+    // std::vector<Reactor *> reactors;
     int threads;
 };
 
@@ -433,7 +735,7 @@ int main(int argc, char **argv)
         return 1;
     }
 
-    params.embedding = true;
+    // params.embedding = true;
 
     if (params.seed <= 0)
     {
diff --git a/examples/grpc-server/message.proto b/examples/grpc-server/message.proto
index c2ad80be8..f4a608949 100644
--- a/examples/grpc-server/message.proto
+++ b/examples/grpc-server/message.proto
@@ -1,6 +1,6 @@
 syntax = "proto3";
 
-package robot;
+package llama;
 
 option go_package = "./pkg/grpc";
 
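Note (not part of the patch): the diff only shows the server side. The sketch below illustrates how a client might consume the server-streaming `Answer` RPC, assuming `message.proto` declares something like `rpc Answer(Job) returns (stream Output)` in the renamed `llama` package and that stubs are generated into `message.grpc.pb.h`; the endpoint address is a placeholder.

```cpp
// Hypothetical client sketch. Service, message, and field names (LlamaGoService,
// Job.prompt, Output.output, Status::FINISHED) are taken from the server code above;
// the address and the generated header name are assumptions.
#include <cstdio>
#include <memory>

#include <grpcpp/grpcpp.h>
#include "message.grpc.pb.h" // generated by protoc from message.proto

int main()
{
    auto channel = grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials());
    std::unique_ptr<llama::LlamaGoService::Stub> stub = llama::LlamaGoService::NewStub(channel);

    llama::Job job;
    job.set_prompt("Hello");

    grpc::ClientContext context;
    std::unique_ptr<grpc::ClientReader<llama::Output>> reader(stub->Answer(&context, job));

    llama::Output chunk;
    while (reader->Read(&chunk))
    {
        // one Output message arrives per generated fragment until the server finishes the stream
        fprintf(stdout, "%s", chunk.output().c_str());
    }
    grpc::Status status = reader->Finish();
    return status.ok() ? 0 : 1;
}
```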