diff --git a/examples/grpc-server/grpc-server.cpp b/examples/grpc-server/grpc-server.cpp
index 35a467dca..a2c7be005 100644
--- a/examples/grpc-server/grpc-server.cpp
+++ b/examples/grpc-server/grpc-server.cpp
@@ -88,8 +88,14 @@ public:
     LlamaServerContext(gpt_params params_) : params(params_), threads(8)
     {
-        ctx = llama_init_from_gpt_params(params);
-        if (ctx == NULL)
+        bool has_embedding = params.embedding;
+        if (params.embedding)
+        {
+            // separate context for embeddings so completion state is not clobbered
+            ctx_for_embedding = llama_init_from_gpt_params(params);
+        }
+        params.embedding = false;
+        ctx_for_completion = llama_init_from_gpt_params(params);
+        if (ctx_for_completion == NULL || (has_embedding && ctx_for_embedding == NULL))
         {
             loaded = false;
             fprintf(stderr, "%s: error: unable to load model\n", __func__);
@@ -107,26 +113,110 @@ public:
     std::vector<float> embedding(std::string content)
     {
         content.insert(0, 1, ' ');
-        std::vector<llama_token> tokens = ::llama_tokenize(ctx, content, true);
+        std::vector<llama_token> tokens = ::llama_tokenize(ctx_for_embedding, content, true);
         if (tokens.size() > 0)
         {
-            fprintf(stderr, "---3---,%p,%d", ctx, threads);
-            if (llama_eval(ctx, tokens.data(), tokens.size(), 0, 6))
+            if (llama_eval(ctx_for_embedding, tokens.data(), tokens.size(), 0, 6))
             {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 std::vector<float> embeddings_;
                 return embeddings_;
             }
         }
-        const int n_embd = llama_n_embd(ctx);
-        const auto embeddings = llama_get_embeddings(ctx);
+        const int n_embd = llama_n_embd(ctx_for_embedding);
+        const auto embeddings = llama_get_embeddings(ctx_for_embedding);
         std::vector<float> embeddings_(embeddings, embeddings + n_embd);
         return embeddings_;
     }
+
+    // sample the next token from the completion context; returns true when generation should stop
+    bool complete(std::string content, int *n_remain, llama_token &result)
+    {
+        const float temp = params.temp;
+        const float top_p = params.top_p;
+        const float tfs_z = params.tfs_z;
+        const float typical_p = params.typical_p;
+        const int repeat_last_n = params.repeat_last_n;
+        const float repeat_penalty = params.repeat_penalty;
+        const float alpha_presence = params.presence_penalty;
+        const float alpha_frequency = params.frequency_penalty;
+        const int mirostat = params.mirostat;
+        const float mirostat_tau = params.mirostat_tau;
+        const float mirostat_eta = params.mirostat_eta;
+        const bool penalize_nl = params.penalize_nl;
+
+        auto logits = llama_get_logits(ctx_for_completion);
+        auto n_vocab = llama_n_vocab(ctx_for_completion);
+
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++)
+        {
+            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        }
+
+        llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
+
+        // Apply penalties
+        float nl_logit = logits[llama_token_nl()];
+        auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
+        llama_sample_repetition_penalty(ctx_for_completion, &candidates_p,
+                                        last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                                        last_n_repeat, repeat_penalty);
+        llama_sample_frequency_and_presence_penalties(ctx_for_completion, &candidates_p,
+                                                      last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                                                      last_n_repeat, alpha_frequency, alpha_presence);
+        if (!penalize_nl)
+        {
+            logits[llama_token_nl()] = nl_logit;
+        }
+
+        llama_token id = 0;
+        if (temp <= 0)
+        {
+            // Greedy sampling
+            id = llama_sample_token_greedy(ctx_for_completion, &candidates_p);
+        }
+        else
+        {
+            if (mirostat == 1)
+            {
+                static float mirostat_mu = 2.0f * mirostat_tau;
+                const int mirostat_m = 100;
+                llama_sample_temperature(ctx_for_completion, &candidates_p, temp);
+                id = llama_sample_token_mirostat(ctx_for_completion, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+            }
+            else if (mirostat == 2)
+            {
+                static float mirostat_mu = 2.0f * mirostat_tau;
+                llama_sample_temperature(ctx_for_completion, &candidates_p, temp);
+                id = llama_sample_token_mirostat_v2(ctx_for_completion, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
+            }
+            else
+            {
+                // Temperature sampling
+                llama_sample_tail_free(ctx_for_completion, &candidates_p, tfs_z, 1);
+                llama_sample_typical(ctx_for_completion, &candidates_p, typical_p, 1);
+                llama_sample_top_p(ctx_for_completion, &candidates_p, top_p, 1);
+                llama_sample_temperature(ctx_for_completion, &candidates_p, temp);
+                id = llama_sample_token(ctx_for_completion, &candidates_p);
+            }
+        }
+
+        result = id;
+        (*n_remain)--;
+        return id == llama_token_eos() || *n_remain <= 0;
+    }
+
+    std::string tokenToString(llama_token token)
+    {
+        if (token == llama_token_eos())
+        {
+            return "";
+        }
+        else if (token == llama_token_nl())
+        {
+            return "\n";
+        }
+        else
+        {
+            return std::string(llama_token_to_str(ctx_for_completion, token));
+        }
+    }
+
 private:
     gpt_params params;
-    llama_context *ctx;
+    llama_context *ctx_for_completion;
+    llama_context *ctx_for_embedding;
     int threads;
     std::vector<llama_token> last_n_tokens;
@@ -143,10 +233,32 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
     Reactor(CallbackServerContext *ctx, const Job *request)
         : ctx_(ctx), request_(request)
     {
+        // evaluate the prompt once, then stream generated tokens back to the client
+        std::string content = request->prompt();
+        content.insert(0, 1, ' ');
+        std::vector<llama_token> tokens = ::llama_tokenize(ctx_for_completion, content, true);
+        if (tokens.size() > 0)
+        {
+            if (llama_eval(ctx_for_completion, tokens.data(), tokens.size(), 0, 6))
+            {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                return;
+            }
+        }
+        // input done, begin the generate loop
+        int n_remain = params.n_predict;
+        bool finished = false;
+        do
+        {
+            llama_token token = 0;
+            finished = llama->complete(request->prompt(), &n_remain, token);
+            Output response;
+            response.set_output(llama->tokenToString(token));
+            StartWrite(&response);
+        } while (!finished);
+
         Output response;
-        // StartWrite(&response_);
-        // StartWriteLast(&response_, WriteOptions());
-        // ctx_->TryCancel();
+        StartWriteLast(&response, WriteOptions());
+        ctx_->TryCancel();
     }

     void OnDone() override { delete this; }
@@ -165,6 +277,7 @@ public:
        CallbackServerContext *context, const Job *request)
    {
        fprintf(stderr, "%s : get answer\n", __func__);
+       std::vector<float> embeded = llama->complete(request->prompt());
        return new Reactor(context, request);
    }
@@ -173,16 +286,10 @@ public:
    {
        fprintf(stderr, "%s : get embed %s\n", __func__, request->prompt().c_str());
        std::vector<float> embeded = llama->embedding(request->prompt());
-       fprintf(stderr, "0");
-       fprintf(stderr, "%p", embeded.begin());
        *response->mutable_embed() = {embeded.begin(), embeded.end()};
-       fprintf(stderr, "1");
        response->set_id(request->id());
-       fprintf(stderr, "2");
        ServerUnaryReactor *reactor = context->DefaultReactor();
-       fprintf(stderr, "3");
        reactor->Finish(Status::OK);
-       fprintf(stderr, "4");
        return reactor;
    }
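For readers unfamiliar with the llama.cpp calls used above, the sketch below strings together the same tokenize, eval, sample, and detokenize steps that `complete()` and the `Reactor` loop depend on, without the gRPC plumbing. It is a minimal illustration only, not part of the patch; it assumes the same pre-GGUF C API this file targets (`llama_init_from_gpt_params` returning a `llama_context *`, `llama_eval` taking an explicit thread count, `llama_token_to_str` returning a C string) and uses greedy sampling instead of the full penalty/mirostat chain.

// Sketch only: one-shot greedy generation with the API calls used in the diff above.
#include "common.h"
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>

int main(int argc, char **argv) {
    gpt_params params;
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    llama_context *ctx = llama_init_from_gpt_params(params);
    if (ctx == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // Tokenize the prompt (leading space + BOS, as in embedding()/Reactor above).
    std::string prompt = " " + params.prompt;
    std::vector<llama_token> tokens = ::llama_tokenize(ctx, prompt, true);

    int n_past = 0;
    if (llama_eval(ctx, tokens.data(), (int) tokens.size(), n_past, params.n_threads)) {
        fprintf(stderr, "failed to eval prompt\n");
        return 1;
    }
    n_past += (int) tokens.size();

    // Greedy loop: pick the most likely token, print it, feed it back into the context.
    for (int i = 0; i < params.n_predict; i++) {
        const float *logits = llama_get_logits(ctx);
        const int n_vocab = llama_n_vocab(ctx);

        std::vector<llama_token_data> candidates;
        candidates.reserve(n_vocab);
        for (llama_token id = 0; id < n_vocab; id++) {
            candidates.emplace_back(llama_token_data{id, logits[id], 0.0f});
        }
        llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};

        llama_token id = llama_sample_token_greedy(ctx, &candidates_p);
        if (id == llama_token_eos()) {
            break;
        }
        printf("%s", llama_token_to_str(ctx, id));
        fflush(stdout);

        if (llama_eval(ctx, &id, 1, n_past, params.n_threads)) {
            fprintf(stderr, "failed to eval token\n");
            return 1;
        }
        n_past += 1;
    }

    llama_free(ctx);
    return 0;
}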