make server accept new request

Liu Ming 2023-06-09 13:22:01 +08:00
parent 0773028d52
commit 4f714548f6

@@ -250,10 +250,6 @@ public:
             n_remain -= new_prompt_len;
         }
         fprintf(stderr, "embd_inp size %d,%d\n", embd_inp.size(), params.n_ctx);
-        if ((int)embd_inp.size() > params.n_ctx - 4)
-        {
-            return false;
-        }
         has_next_token = true;
         return true;
     }
@@ -274,7 +270,7 @@ public:
     llama_token nextToken()
     {
         llama_token result = -1;
-        // fprintf(stderr, "embed size %d,%d,%d\n", embd.size(), embd_inp.size(), n_consumed);
+        // fprintf(stderr, "embed size %d,%d,%d,%d\n", embd.size(), embd_inp.size(), n_consumed,n_past);
         if (embd.size() > 0)
         {
             if (n_past + (int)embd.size() > params.n_ctx)
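
With the hard prompt-length rejection removed in the first hunk, overflow handling is left to nextToken(), whose guard on n_past + embd.size() is the last context line above. For reference, a sketch of the usual llama.cpp-style context swap that such a guard typically leads into; the body in this repository is outside the diff, so the details (n_keep, last_n_tokens) are assumptions:

    // Sketch only: typical llama.cpp context swap, not code from this commit.
    // Keep the first n_keep tokens, then re-feed the most recent half of the
    // evaluated tokens so generation can continue past the context limit.
    if (n_past + (int)embd.size() > params.n_ctx)
    {
        const int n_left = n_past - params.n_keep;
        n_past = params.n_keep;
        embd.insert(embd.begin(),
                    last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(),
                    last_n_tokens.end() - embd.size());
    }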
@@ -522,7 +518,7 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
     }
     void OnDone() override
     {
-        fprintf(stderr, "completion done");
+        fprintf(stderr, "completion done\n");
         delete this;
     }
     void OnWriteDone(bool /*ok*/) override
@@ -532,8 +528,7 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
     }
     void OnCancel() override
     {
-        fprintf(stderr, "on cancel");
-        delete this;
+        FinishOnce(grpc::Status::OK);
     }

 private:
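
FinishOnce itself is not part of this diff; given the finish_mu_ and finished_ members in the next hunk, it is presumably a guard that calls Finish() at most once, whether the stream completes normally or the client cancels. OnCancel() then no longer needs `delete this`, because OnDone() still runs after Finish() and performs the delete. A minimal sketch of such a helper, assuming those members:

    // Sketch, not the code in this commit: call Finish() exactly once so both
    // the normal completion path and OnCancel() can route through it safely.
    void FinishOnce(grpc::Status s)
    {
        std::lock_guard<std::mutex> l(finish_mu_);
        if (!finished_)
        {
            finished_ = true;
            Finish(s);
        }
    }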
@@ -543,19 +538,20 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
     int n_remain{0};
     std::mutex finish_mu_;
     bool finished_{false};
-    Output response;
+    Output *response;

     void NextWrite()
     {
+        response = new Output();
         // loop inference until finish completion
         if (llama_->has_next_token)
         {
             std::lock_guard<std::mutex> l(finish_mu_);
             auto result = llama_->doCompletion();
             fprintf(stderr, "%s", result.c_str());
-            response.set_status(llama::Status::RUNNING);
-            response.set_output(result);
-            StartWrite(&response);
+            response->set_status(llama::Status::RUNNING);
+            response->set_output(result);
+            StartWrite(response);
         }
         else
         {
@@ -564,8 +560,8 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
                 l(finish_mu_);
             if (!finished_)
             {
-                response.set_status(llama::Status::FINISHED);
-                StartWriteLast(&response, grpc::WriteOptions());
+                response->set_status(llama::Status::FINISHED);
+                StartWriteLast(response, grpc::WriteOptions());
             }
         }
         // If we use WriteLast, we shouldn't wait before attempting Finish
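
Switching the member from a value (`Output response`) to a heap-allocated `Output *response`, created per NextWrite() call, fits gRPC's callback API: the message handed to StartWrite()/StartWriteLast() must stay valid until the corresponding OnWriteDone() (or OnDone()), so giving each write its own object keeps an in-flight message from being overwritten by the next NextWrite(). The diff itself never frees the previous Output; one possible follow-up, sketched here under the assumption that OnWriteDone() is what re-enters NextWrite() to continue the loop, is:

    // Sketch only, not part of this commit: release the Output written by the
    // previous StartWrite() before scheduling the next token, so a long
    // completion does not accumulate one allocation per generated token.
    void OnWriteDone(bool ok) override
    {
        delete response;
        response = nullptr;
        if (ok)
        {
            NextWrite();
        }
    }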