1) Change the C++14 request from global to grpc-server only

2) Move the proto file into the llama/v1 directory, per the lint suggestion
Liu Ming 2023-06-12 09:57:47 +08:00
parent 837f04b870
commit 428e734bb6
6 changed files with 68 additions and 45 deletions


@@ -116,7 +116,11 @@ endif()
 # Compile flags
 #
+if(LLAMA_BUILD_GRPC_SERVER)
+set(CMAKE_CXX_STANDARD 11)
+else()
 set(CMAKE_CXX_STANDARD 14)
+endif()
 set(CMAKE_CXX_STANDARD_REQUIRED true)
 set(CMAKE_C_STANDARD 11)
 set(CMAKE_C_STANDARD_REQUIRED true)


@@ -15,7 +15,7 @@ message(STATUS "Using protobuf ${Protobuf_VERSION} ${Protobuf_INCLUDE_DIRS} ${CM
 # Proto file
-get_filename_component(hw_proto "./message.proto" ABSOLUTE)
+get_filename_component(hw_proto "./llama/v1/message.proto" ABSOLUTE)
 get_filename_component(hw_proto_path "${hw_proto}" PATH)
 # Generated sources


@@ -0,0 +1,14 @@
+# llama grpc server
+
+A gRPC service that serves completion and embedding (when the `--embedding` argument is given), based on examples/server.
+
+## Running the service
+
+Run grpc-server with the same arguments as the llama.cpp main program, plus the following:
+
+* `--host` sets the listening host
+* `--port` sets the listening port
+
+### Behavioral differences from examples/server
+
+* grpc-server always stops generation when `<eos>` is the predicted token.
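From a client's point of view, the Complete stream described in this README delivers one message per generated piece of text and ends once `<eos>` is predicted (the final message carries status FINISHED, as the server code below shows). A minimal consumer might look like the following sketch; it is not part of this commit, and the listen address and the generated header name (derived from llama/v1/message.proto) are assumptions.

```cpp
// Hypothetical client for the streaming Complete RPC (not part of this commit).
// Assumes code generated from llama/v1/message.proto and a server at localhost:50051.
#include <cstdio>
#include <memory>
#include <grpcpp/grpcpp.h>
#include "message.grpc.pb.h"

int main() {
    auto channel = grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials());
    std::unique_ptr<llama::v1::LlamaService::Stub> stub = llama::v1::LlamaService::NewStub(channel);

    llama::v1::Request request;
    request.set_id("demo");
    request.set_prompt("Building a website can be done in 10 simple steps:");

    grpc::ClientContext ctx;
    std::unique_ptr<grpc::ClientReader<llama::v1::CompletionResponse>> reader(
        stub->Complete(&ctx, request));

    llama::v1::CompletionResponse msg;
    while (reader->Read(&msg)) {
        // Each message carries one generated piece of text; the server closes the
        // stream with a FINISHED status once <eos> is predicted.
        fputs(msg.output().c_str(), stdout);
    }
    grpc::Status status = reader->Finish();
    fprintf(stderr, "\nstream closed: %s\n", status.ok() ? "ok" : status.error_message().c_str());
    return status.ok() ? 0 : 1;
}
```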


@@ -41,9 +41,10 @@ using grpc::ServerContext;
 using grpc::ServerUnaryReactor;
 using grpc::ServerWriteReactor;
 using grpc::Status;
-using llama::Job;
-using llama::LlamaGoService;
-using llama::Output;
+using llama::v1::Request;
+using llama::v1::LlamaService;
+using llama::v1::EmbedResponse;
+using llama::v1::CompletionResponse;

 struct server_params
 {
@@ -497,13 +498,13 @@ private:
 };

 // Logic and data behind the server's behavior.
-class LlamaServiceImpl final : public LlamaGoService::CallbackService
+class LlamaServiceImpl final : public LlamaService::CallbackService
 {
-    class Reactor : public grpc::ServerWriteReactor<Output>
+    class Reactor : public grpc::ServerWriteReactor<CompletionResponse>
     {
     public:
-        Reactor(CallbackServerContext *ctx, LlamaServerContext *llama, const Job *request)
+        Reactor(CallbackServerContext *ctx, LlamaServerContext *llama, const Request *request)
             : ctx_(ctx), request_(request), llama_(llama)
         {
             if (llama->loadPrompt(request->prompt()))
@@ -534,22 +535,22 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
     private:
         CallbackServerContext *const ctx_;
         LlamaServerContext *llama_;
-        const Job *const request_;
+        const Request *const request_;
         int n_remain{0};
         std::mutex finish_mu_;
         bool finished_{false};
-        Output *response;
+        CompletionResponse *response;

         void NextWrite()
         {
-            response = new Output();
+            response = new CompletionResponse();
             // loop inference until finish completion
             if (llama_->has_next_token)
             {
                 std::lock_guard<std::mutex> l(finish_mu_);
                 auto result = llama_->doCompletion();
                 fprintf(stderr, "%s", result.c_str());
-                response->set_status(llama::Status::RUNNING);
+                response->set_status(llama::v1::Status::RUNNING);
                 response->set_output(result);
                 StartWrite(response);
             }
@@ -560,7 +561,7 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
                     l(finish_mu_);
                 if (!finished_)
                 {
-                    response->set_status(llama::Status::FINISHED);
+                    response->set_status(llama::v1::Status::FINISHED);
                     StartWriteLast(response, grpc::WriteOptions());
                 }
             }
@@ -585,8 +586,8 @@ public:
     {
     }

-    ServerWriteReactor<Output> *Answer(
-        CallbackServerContext *context, const Job *request)
+    ServerWriteReactor<CompletionResponse> *Complete(
+        CallbackServerContext *context, const Request *request)
     {
         fprintf(stderr, "%s : new answer request: %s\n", __func__, request->prompt().c_str());
         llama->rewind();
@@ -598,7 +599,7 @@ public:
     }

     ServerUnaryReactor *Embed(
-        CallbackServerContext *context, const Job *request, Output *response)
+        CallbackServerContext *context, const Request *request, EmbedResponse *response)
     {
         fprintf(stderr, "%s : get embed %s\n", __func__, request->prompt().c_str());
         std::vector<float> embeded = llama->embedding(request->prompt());
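The Embed handler above is a unary RPC. A client-side helper could look roughly like this sketch; it is not from the commit, and the generated header name and the error handling are assumptions.

```cpp
// Hypothetical helper that calls the unary Embed RPC and copies the repeated
// float field into a std::vector. Not part of this commit.
#include <string>
#include <vector>
#include <grpcpp/grpcpp.h>
#include "message.grpc.pb.h"

std::vector<float> embed_prompt(llama::v1::LlamaService::Stub &stub, const std::string &prompt) {
    grpc::ClientContext ctx;

    llama::v1::Request request;
    request.set_prompt(prompt);

    llama::v1::EmbedResponse response;
    grpc::Status status = stub.Embed(&ctx, request, &response);
    if (!status.ok()) {
        // A real caller would surface status.error_message() to its user.
        return {};
    }
    // `embed` is the repeated float field defined in llama/v1/message.proto.
    return std::vector<float>(response.embed().begin(), response.embed().end());
}
```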


@@ -0,0 +1,33 @@
+syntax = "proto3";
+
+package llama.v1;
+
+option go_package = "./pkg/grpc";
+
+service LlamaService {
+  rpc Complete(Request) returns (stream CompletionResponse){}
+  rpc Embed(Request) returns (EmbedResponse){}
+}
+
+message Request {
+  string id = 1;
+  string prompt = 2;
+}
+
+enum Status {
+  PENDING_UNSPECIFIED = 0;
+  RUNNING = 1;
+  FINISHED = 2;
+  FAILED = 3;
+}
+
+message CompletionResponse {
+  string id = 1;
+  Status status = 2;
+  string output = 3;
+}
+
+message EmbedResponse {
+  string id = 1;
+  repeated float embed = 2;
+}
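A callback service generated from this proto is hosted with the standard gRPC C++ ServerBuilder. The sketch below shows only the general shape; the listen address and the way the service object is constructed are assumptions, not the repository's actual startup code.

```cpp
// Generic hosting sketch for a CallbackService generated from
// llama/v1/message.proto. Not the repository's actual main().
#include <cstdio>
#include <memory>
#include <string>
#include <grpcpp/grpcpp.h>
#include "message.grpc.pb.h"

void run_server(llama::v1::LlamaService::CallbackService *service, const std::string &addr) {
    grpc::ServerBuilder builder;
    builder.AddListeningPort(addr, grpc::InsecureServerCredentials());
    builder.RegisterService(service);  // e.g. the LlamaServiceImpl shown in grpc-server.cpp
    std::unique_ptr<grpc::Server> server = builder.BuildAndStart();
    fprintf(stderr, "grpc-server listening on %s\n", addr.c_str());
    server->Wait();  // blocks until Shutdown() is called elsewhere
}
```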


@@ -1,29 +0,0 @@
-syntax = "proto3";
-
-package llama;
-
-option go_package = "./pkg/grpc";
-
-service LlamaGoService {
-  rpc Answer(Job) returns (stream Output){}
-  rpc Embed(Job) returns (Output){}
-}
-
-message Job {
-  string id = 1;
-  string prompt = 2;
-}
-
-enum Status {
-  PENDING = 0;
-  RUNNING = 1;
-  FINISHED = 2;
-  FAILED = 3;
-}
-
-message Output {
-  string id = 1;
-  Status status = 2;
-  string output = 3;
-  repeated float embed = 4;
-}