add grpc server

刘铭 2023-05-23 17:07:50 +08:00
parent 7e4ea5beff
commit c645f1e8c1
5 changed files with 434 additions and 1 deletion

CMakeLists.txt

@@ -73,6 +73,7 @@ option(LLAMA_CLBLAST "llama: use CLBlast"
 option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_SERVER "llama: build server example" OFF)
+option(LLAMA_BUILD_GRPC_SERVER "llama: build grpc server example" OFF)
 #
 # Build info header
@@ -111,7 +112,7 @@ endif()
 # Compile flags
 #
-set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_CXX_STANDARD 14)
 set(CMAKE_CXX_STANDARD_REQUIRED true)
 set(CMAKE_C_STANDARD 11)
 set(CMAKE_C_STANDARD_REQUIRED true)

examples/CMakeLists.txt

@@ -40,4 +40,7 @@ else()
     if(LLAMA_BUILD_SERVER)
         add_subdirectory(server)
     endif()
+    if(LLAMA_BUILD_GRPC_SERVER)
+        add_subdirectory(grpc-server)
+    endif()
 endif()

examples/grpc-server/CMakeLists.txt

@@ -0,0 +1,52 @@
set(TARGET grpc-server)

# Locate protobuf, gRPC, absl and the code-generation tools
# (MY_INSTALL_DIR points at the prefix where gRPC was installed)
find_package(absl REQUIRED)
find_package(Protobuf CONFIG REQUIRED PATHS ${MY_INSTALL_DIR}/lib)
find_package(gRPC CONFIG REQUIRED)
find_program(_PROTOBUF_PROTOC protoc)
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)

set(_PROTOBUF_LIBPROTOBUF libprotobuf)
set(_REFLECTION grpc++_reflection)
set(_GRPC_GRPCPP grpc++)

include_directories($ENV{MY_INSTALL_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
message(STATUS "Using protobuf ${Protobuf_VERSION} ${CMAKE_CURRENT_BINARY_DIR} $ENV{MY_INSTALL_DIR}/include")
# Proto file
get_filename_component(hw_proto "./message.proto" ABSOLUTE)
get_filename_component(hw_proto_path "${hw_proto}" PATH)
# Generated sources
set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/message.pb.cc")
set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/message.pb.h")
set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/message.grpc.pb.cc")
set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/message.grpc.pb.h")
add_custom_command(
OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}"
COMMAND ${_PROTOBUF_PROTOC}
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
--cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
-I "${hw_proto_path}"
--plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
"${hw_proto}"
DEPENDS "${hw_proto}")
# hw_grpc_proto
add_library(hw_grpc_proto
${hw_grpc_srcs}
${hw_grpc_hdrs}
${hw_proto_srcs}
${hw_proto_hdrs})
add_executable(${TARGET} grpc-server.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} hw_grpc_proto absl::flags
absl::flags_parse
gRPC::${_REFLECTION}
gRPC::${_GRPC_GRPCPP}
protobuf::${_PROTOBUF_LIBPROTOBUF})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()

examples/grpc-server/grpc-server.cpp

@@ -0,0 +1,348 @@
/*
*
* Copyright 2015 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
#include <ctime>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "common.h"
#include "llama.h"
#include <grpcpp/grpcpp.h>
#include "absl/strings/str_format.h"
#ifdef BAZEL_BUILD
#include "examples/protos/message.grpc.pb.h"
#else
#include "message.grpc.pb.h"
#endif
// ABSL_FLAG(uint16_t, port, 50051, "Server port for the service");
// ABSL_FLAG(std::string, target, "localhost:50051", "Server address");
using grpc::CallbackServerContext;
using grpc::Server;
using grpc::ServerAsyncWriter;
using grpc::ServerBuilder;
using grpc::ServerCompletionQueue;
using grpc::ServerContext;
using grpc::ServerUnaryReactor;
using grpc::ServerWriteReactor;
using grpc::Status;
using robot::Job;
using robot::LlamaGoService;
using robot::Output;
struct server_params
{
std::string hostname = "127.0.0.1";
int32_t port = 8080;
};
void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
{
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n");
fprintf(stderr, " --embedding enable embedding mode\n");
fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
if (llama_mlock_supported())
{
fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
}
if (llama_mmap_supported())
{
fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
}
fprintf(stderr, " -ngl N, --n-gpu-layers N\n");
fprintf(stderr, " number of layers to store in VRAM\n");
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
fprintf(stderr, " -host ip address to listen (default 127.0.0.1)\n");
fprintf(stderr, " -port PORT port to listen (default 8080)\n");
fprintf(stderr, "\n");
}
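// Holds the shared llama_context and model state used by the RPC handlers.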
class LlamaServerContext
{
public:
bool loaded;
LlamaServerContext(gpt_params params_) : params(params_), threads(params_.n_threads)
{
ctx = llama_init_from_gpt_params(params);
if (ctx == NULL)
{
loaded = false;
fprintf(stderr, "%s: error: unable to load model\n", __func__);
}
else
{
loaded = true;
// determine newline token
llama_token_newline = ::llama_tokenize(ctx, "\n", false);
last_n_tokens.resize(params.n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
}
}
std::vector<float> embedding(std::string content)
{
// llama expects a leading space in front of the prompt
content.insert(0, 1, ' ');
std::vector<llama_token> tokens = ::llama_tokenize(ctx, content, true);
if (tokens.size() > 0)
{
if (llama_eval(ctx, tokens.data(), tokens.size(), 0, threads))
{
fprintf(stderr, "%s : failed to eval\n", __func__);
return std::vector<float>();
}
}
const int n_embd = llama_n_embd(ctx);
const auto embeddings = llama_get_embeddings(ctx);
std::vector<float> embeddings_(embeddings, embeddings + n_embd);
return embeddings_;
}
private:
gpt_params params;
llama_context *ctx;
int threads;
std::vector<llama_token> last_n_tokens;
std::vector<llama_token> llama_token_newline;
};
// Logic and data behind the server's behavior.
class LlamaServiceImpl final : public LlamaGoService::CallbackService
{
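// Per-call reactor used by the server-streaming Answer RPC; deletes itself in OnDone().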
class Reactor : public grpc::ServerWriteReactor<Output>
{
public:
Reactor(CallbackServerContext *ctx, const Job *request)
: ctx_(ctx), request_(request)
{
// Streaming generation is not implemented yet; finish the stream
// immediately so the Answer RPC does not hang waiting for writes.
Finish(Status::OK);
}
void OnDone() override { delete this; }
private:
CallbackServerContext *const ctx_;
const Job *const request_;
};
public:
LlamaServiceImpl(LlamaServerContext *llama_) : llama(llama_)
{
fprintf(stderr, "%s : new impl\n", __func__);
}
ServerWriteReactor<Output> *Answer(
CallbackServerContext *context, const Job *request) override
{
fprintf(stderr, "%s : get answer\n", __func__);
return new Reactor(context, request);
}
ServerUnaryReactor *Embed(
CallbackServerContext *context, const Job *request, Output *response) override
{
fprintf(stderr, "%s : get embed %s\n", __func__, request->prompt().c_str());
std::vector<float> embedded = llama->embedding(request->prompt());
*response->mutable_embed() = {embedded.begin(), embedded.end()};
response->set_id(request->id());
ServerUnaryReactor *reactor = context->DefaultReactor();
reactor->Finish(Status::OK);
return reactor;
}
private:
LlamaServerContext *llama;
int threads;
};
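// Build and start the gRPC server on 0.0.0.0:<port>, then block until it is shut down.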
void RunServer(uint16_t port, LlamaServerContext *llama)
{
std::string server_address = absl::StrFormat("0.0.0.0:%d", port);
LlamaServiceImpl service(llama);
grpc::EnableDefaultHealthCheckService(true);
ServerBuilder builder;
// Listen on the given address without any authentication mechanism.
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
// Register "service" as the instance through which we'll communicate with
// clients. In this case it corresponds to an *synchronous* service.
builder.RegisterService(&service);
// Finally assemble the server.
std::unique_ptr<Server> server(builder.BuildAndStart());
std::cout << "Server listening on " << server_address << std::endl;
// Wait for the server to shutdown. Note that some other thread must be
// responsible for shutting down the server for this call to ever return.
server->Wait();
}
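// Parses the server-specific flags (--host, --port) plus a subset of the common
// llama options; prints usage and exits on an unknown argument.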
bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
{
gpt_params default_params;
std::string arg;
bool invalid_param = false;
for (int i = 1; i < argc; i++)
{
arg = argv[i];
if (arg == "--port")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
sparams.port = std::stoi(argv[i]);
}
else if (arg == "--host")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
sparams.hostname = argv[i];
}
else if (arg == "-s" || arg == "--seed")
{
#if defined(GGML_USE_CUBLAS)
fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
#endif
if (++i >= argc)
{
invalid_param = true;
break;
}
params.seed = std::stoi(argv[i]);
}
else if (arg == "-m" || arg == "--model")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.model = argv[i];
}
else if (arg == "--embedding")
{
params.embedding = true;
}
else if (arg == "-h" || arg == "--help")
{
server_print_usage(argc, argv, default_params);
exit(0);
}
else if (arg == "-c" || arg == "--ctx_size")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.n_ctx = std::stoi(argv[i]);
}
else if (arg == "--memory_f32")
{
params.memory_f16 = false;
}
else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.n_gpu_layers = std::stoi(argv[i]);
}
else
{
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
server_print_usage(argc, argv, default_params);
exit(1);
}
}
if (invalid_param)
{
fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
server_print_usage(argc, argv, default_params);
exit(1);
}
return true;
}
int main(int argc, char **argv)
{
gpt_params params;
server_params sparams;
llama_init_backend();
params.model = "ggml-model.bin";
params.n_ctx = 512;
sparams.port = 8080;
if (server_params_parse(argc, argv, sparams, params) == false)
{
return 1;
}
params.embedding = true;
if (params.seed <= 0)
{
params.seed = time(NULL);
}
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
LlamaServerContext llama(params);
// load the model
if (!llama.loaded)
{
return 1;
}
RunServer(sparams.port, &llama);
return 0;
}

examples/grpc-server/message.proto

@@ -0,0 +1,29 @@
syntax = "proto3";

package robot;

option go_package = "./pkg/grpc";

service LlamaGoService {
    rpc Answer(Job) returns (stream Output) {}
    rpc Embed(Job) returns (Output) {}
}

message Job {
    string id = 1;
    string prompt = 2;
}

enum Status {
    PENDING = 0;
    RUNNING = 1;
    FINISHED = 2;
    FAILED = 3;
}

message Output {
    string id = 1;
    Status status = 2;
    string output = 3;
    repeated float embed = 4;
}
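
For reference, below is a minimal synchronous C++ client sketch for the service defined above. It is not part of this commit: it assumes the stubs generated from message.proto (message.pb.h / message.grpc.pb.h) and the server's default port 8080, and with the current server the Answer stream completes without emitting any output.

// grpc-client.cpp — illustrative only, not part of this commit.
#include <iostream>
#include <memory>
#include <string>

#include <grpcpp/grpcpp.h>
#include "message.grpc.pb.h"

int main()
{
    // The address assumes the server's default port 8080 (see main() above).
    auto channel = grpc::CreateChannel("localhost:8080", grpc::InsecureChannelCredentials());
    std::unique_ptr<robot::LlamaGoService::Stub> stub = robot::LlamaGoService::NewStub(channel);

    robot::Job job;
    job.set_id("job-1");
    job.set_prompt("Hello, llama");

    // Unary Embed call: one Job in, one Output back with the repeated float embed field set.
    grpc::ClientContext embed_ctx;
    robot::Output embed_out;
    grpc::Status status = stub->Embed(&embed_ctx, job, &embed_out);
    if (!status.ok())
    {
        std::cerr << "Embed failed: " << status.error_message() << std::endl;
        return 1;
    }
    std::cout << "embedding size: " << embed_out.embed_size() << std::endl;

    // Server-streaming Answer call: read Output messages until the stream ends.
    grpc::ClientContext answer_ctx;
    std::unique_ptr<grpc::ClientReader<robot::Output>> reader(stub->Answer(&answer_ctx, job));
    robot::Output chunk;
    while (reader->Read(&chunk))
    {
        std::cout << chunk.output();
    }
    grpc::Status stream_status = reader->Finish();
    std::cout << std::endl << "Answer stream ok: " << stream_status.ok() << std::endl;
    return 0;
}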