diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3471e44f2..a4b4c8233 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -73,6 +73,7 @@ option(LLAMA_CLBLAST                         "llama: use CLBlast"
 option(LLAMA_BUILD_TESTS                "llama: build tests"    ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_EXAMPLES             "llama: build examples" ${LLAMA_STANDALONE})
 option(LLAMA_BUILD_SERVER               "llama: build server example"           OFF)
+option(LLAMA_BUILD_GRPC_SERVER          "llama: build grpc server example"      OFF)
 
 #
 # Build info header
@@ -111,7 +112,7 @@ endif()
 # Compile flags
 #
 
-set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_CXX_STANDARD 14)
 set(CMAKE_CXX_STANDARD_REQUIRED true)
 set(CMAKE_C_STANDARD 11)
 set(CMAKE_C_STANDARD_REQUIRED true)
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index e4ce5aca7..7a7c3c2e6 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -40,4 +40,7 @@ else()
     if(LLAMA_BUILD_SERVER)
         add_subdirectory(server)
     endif()
+    if(LLAMA_BUILD_GRPC_SERVER)
+        add_subdirectory(grpc-server)
+    endif()
 endif()
diff --git a/examples/grpc-server/CMakeLists.txt b/examples/grpc-server/CMakeLists.txt
new file mode 100644
index 000000000..51540ca4b
--- /dev/null
+++ b/examples/grpc-server/CMakeLists.txt
@@ -0,0 +1,52 @@
+set(TARGET grpc-server)
+set(_PROTOBUF_LIBPROTOBUF libprotobuf)
+set(_REFLECTION grpc++_reflection)
+find_package(absl REQUIRED)
+find_package(Protobuf CONFIG REQUIRED PATHS ${MY_INSTALL_DIR}/lib)
+include_directories($ENV{MY_INSTALL_DIR}/include)
+find_package(gRPC CONFIG REQUIRED)
+find_program(_PROTOBUF_PROTOC protoc)
+set(_GRPC_GRPCPP grpc++)
+find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
+
+message(STATUS "Using protobuf ${Protobuf_VERSION} ${CMAKE_CURRENT_BINARY_DIR} $ENV{MY_INSTALL_DIR}/include")
+
+
+# Proto file
+get_filename_component(hw_proto "./message.proto" ABSOLUTE)
+get_filename_component(hw_proto_path "${hw_proto}" PATH)
+
+# Generated sources
+set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/message.pb.cc")
+set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/message.pb.h")
+set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/message.grpc.pb.cc")
+set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/message.grpc.pb.h")
+
+add_custom_command(
+    OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}"
+    COMMAND ${_PROTOBUF_PROTOC}
+    ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
+        --cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
+        -I "${hw_proto_path}"
+        --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
+        "${hw_proto}"
+    DEPENDS "${hw_proto}")
+
+# hw_grpc_proto
+add_library(hw_grpc_proto
+    ${hw_grpc_srcs}
+    ${hw_grpc_hdrs}
+    ${hw_proto_srcs}
+    ${hw_proto_hdrs})
+
+add_executable(${TARGET} grpc-server.cpp)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} hw_grpc_proto absl::flags
+    absl::flags_parse
+    gRPC::${_REFLECTION}
+    gRPC::${_GRPC_GRPCPP}
+    protobuf::${_PROTOBUF_LIBPROTOBUF})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+if(TARGET BUILD_INFO)
+    add_dependencies(${TARGET} BUILD_INFO)
+endif()
diff --git a/examples/grpc-server/grpc-server.cpp b/examples/grpc-server/grpc-server.cpp
new file mode 100644
index 000000000..35a467dca
--- /dev/null
+++ b/examples/grpc-server/grpc-server.cpp
@@ -0,0 +1,348 @@
+/*
+ *
+ * Copyright 2015 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include "common.h"
+#include "llama.h"
+
+#include <grpcpp/grpcpp.h>
+
+#include "absl/strings/str_format.h"
+
+#ifdef BAZEL_BUILD
+#include "examples/protos/message.grpc.pb.h"
+#else
+#include "message.grpc.pb.h"
+#endif
+
+// ABSL_FLAG(uint16_t, port, 50051, "Server port for the service");
+// ABSL_FLAG(std::string, target, "localhost:50051", "Server address");
+
+using grpc::CallbackServerContext;
+using grpc::Server;
+using grpc::ServerAsyncWriter;
+using grpc::ServerBuilder;
+using grpc::ServerCompletionQueue;
+using grpc::ServerContext;
+using grpc::ServerUnaryReactor;
+using grpc::ServerWriteReactor;
+using grpc::Status;
+using robot::Job;
+using robot::LlamaGoService;
+using robot::Output;
+
+struct server_params
+{
+    std::string hostname = "127.0.0.1";
+    int32_t port = 8080;
+};
+
+void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
+{
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1, use random seed for < 0)\n");
+    fprintf(stderr, "  --memory_f32          use f32 instead of f16 for memory key+value\n");
+    fprintf(stderr, "  --embedding           enable embedding mode\n");
+    fprintf(stderr, "  --keep                number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
+    if (llama_mlock_supported())
+    {
+        fprintf(stderr, "  --mlock               force system to keep model in RAM rather than swapping or compressing\n");
+    }
+    if (llama_mmap_supported())
+    {
+        fprintf(stderr, "  --no-mmap             do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
+    }
+    fprintf(stderr, "  -ngl N, --n-gpu-layers N\n");
+    fprintf(stderr, "                        number of layers to store in VRAM\n");
+    fprintf(stderr, "  -m FNAME, --model FNAME\n");
+    fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stderr, "  --host                ip address to listen (default 127.0.0.1)\n");
+    fprintf(stderr, "  --port PORT           port to listen (default 8080)\n");
+    fprintf(stderr, "\n");
+}
+
+class LlamaServerContext
+{
+public:
+    bool loaded;
+
+    LlamaServerContext(gpt_params params_) : params(params_), threads(8)
+    {
+        ctx = llama_init_from_gpt_params(params);
+        if (ctx == NULL)
+        {
+            loaded = false;
+            fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        }
+        else
+        {
+            loaded = true;
+            // determine newline token
+            llama_token_newline = ::llama_tokenize(ctx, "\n", false);
+            last_n_tokens.resize(params.n_ctx);
+            std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+        }
+    }
+
+    std::vector<float> embedding(std::string content)
+    {
+        content.insert(0, 1, ' ');
+        std::vector<llama_token> tokens = ::llama_tokenize(ctx, content, true);
+        if (tokens.size() > 0)
+        {
+            fprintf(stderr, "---3---,%p,%d", ctx, threads);
+            if (llama_eval(ctx, tokens.data(), tokens.size(), 0, 6))
+            {
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                std::vector<float> embeddings_;
+                return embeddings_;
+            }
+        }
+        const int n_embd = llama_n_embd(ctx);
+        const auto embeddings = llama_get_embeddings(ctx);
+        std::vector<float> embeddings_(embeddings, embeddings + n_embd);
+        return embeddings_;
+    }
+
+private:
+    gpt_params params;
+    llama_context *ctx;
+    int threads;
+
+    std::vector<llama_token> last_n_tokens;
+    std::vector<llama_token> llama_token_newline;
+};
+
+// Logic and data behind the server's behavior.
+class LlamaServiceImpl final : public LlamaGoService::CallbackService
+{
+
+    class Reactor : public grpc::ServerWriteReactor<Output>
+    {
+    public:
+        Reactor(CallbackServerContext *ctx, const Job *request)
+            : ctx_(ctx), request_(request)
+        {
+            Output response;
+            // StartWrite(&response_);
+            // StartWriteLast(&response_, WriteOptions());
+            // ctx_->TryCancel();
+        }
+        void OnDone() override { delete this; }
+
+    private:
+        CallbackServerContext *const ctx_;
+        const Job *const request_;
+    };
+
+public:
+    LlamaServiceImpl(LlamaServerContext *llama_) : llama(llama_)
+    {
+        fprintf(stderr, "%s : new impl\n", __func__);
+    }
+
+    ServerWriteReactor<Output> *Answer(
+        CallbackServerContext *context, const Job *request)
+    {
+        fprintf(stderr, "%s : get answer\n", __func__);
+        return new Reactor(context, request);
+    }
+
+    ServerUnaryReactor *Embed(
+        CallbackServerContext *context, const Job *request, Output *response)
+    {
+        fprintf(stderr, "%s : get embed %s\n", __func__, request->prompt().c_str());
+        std::vector<float> embeded = llama->embedding(request->prompt());
+        fprintf(stderr, "0");
+        fprintf(stderr, "%p", (void *)embeded.data());
+        *response->mutable_embed() = {embeded.begin(), embeded.end()};
+        fprintf(stderr, "1");
+        response->set_id(request->id());
+        fprintf(stderr, "2");
+        ServerUnaryReactor *reactor = context->DefaultReactor();
+        fprintf(stderr, "3");
+        reactor->Finish(Status::OK);
+        fprintf(stderr, "4");
+        return reactor;
+    }
+
+private:
+    LlamaServerContext *llama;
+    int threads;
+};
+
+void RunServer(uint16_t port, LlamaServerContext *llama)
+{
+    std::string server_address = absl::StrFormat("0.0.0.0:%d", port);
+    LlamaServiceImpl service(llama);
+
+    grpc::EnableDefaultHealthCheckService(true);
+    ServerBuilder builder;
+    // Listen on the given address without any authentication mechanism.
+    builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
+    // Register "service" as the instance through which we'll communicate with
+    // clients. In this case it corresponds to a *synchronous* service.
+    builder.RegisterService(&service);
+    // Finally assemble the server.
+    std::unique_ptr<Server> server(builder.BuildAndStart());
+    std::cout << "Server listening on " << server_address << std::endl;
+
+    // Wait for the server to shutdown. Note that some other thread must be
+    // responsible for shutting down the server for this call to ever return.
+    server->Wait();
+}
+
+bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
+{
+    gpt_params default_params;
+    std::string arg;
+    bool invalid_param = false;
+
+    for (int i = 1; i < argc; i++)
+    {
+        arg = argv[i];
+        if (arg == "--port")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            sparams.port = std::stoi(argv[i]);
+        }
+        else if (arg == "--host")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            sparams.hostname = argv[i];
+        }
+        else if (arg == "-s" || arg == "--seed")
+        {
+#if defined(GGML_USE_CUBLAS)
+            fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
+#endif
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.seed = std::stoi(argv[i]);
+        }
+        else if (arg == "-m" || arg == "--model")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.model = argv[i];
+        }
+        else if (arg == "--embedding")
+        {
+            params.embedding = true;
+        }
+        else if (arg == "-h" || arg == "--help")
+        {
+            server_print_usage(argc, argv, default_params);
+            exit(0);
+        }
+        else if (arg == "-c" || arg == "--ctx_size")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.n_ctx = std::stoi(argv[i]);
+        }
+        else if (arg == "--memory_f32")
+        {
+            params.memory_f16 = false;
+        }
+        else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            params.n_gpu_layers = std::stoi(argv[i]);
+        }
+        else
+        {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            server_print_usage(argc, argv, default_params);
+            exit(1);
+        }
+    }
+
+    if (invalid_param)
+    {
+        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
+        server_print_usage(argc, argv, default_params);
+        exit(1);
+    }
+    return true;
+}
+
+int main(int argc, char **argv)
+{
+
+    gpt_params params;
+    server_params sparams;
+
+    llama_init_backend();
+
+    params.model = "ggml-model.bin";
+    params.n_ctx = 512;
+
+    sparams.port = 8080;
+
+    if (gpt_params_parse(argc, argv, params) == false)
+    {
+        return 1;
+    }
+
+    params.embedding = true;
+
+    if (params.seed <= 0)
+    {
+        params.seed = time(NULL);
+    }
+
+    fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
+
+    LlamaServerContext llama(params);
+
+    // load the model
+    if (!llama.loaded)
+    {
+        return 1;
+    }
+
+    RunServer(sparams.port, &llama);
+    return 0;
+}
diff --git a/examples/grpc-server/message.proto b/examples/grpc-server/message.proto
new file mode 100755
index 000000000..c2ad80be8
--- /dev/null
+++ b/examples/grpc-server/message.proto
@@ -0,0 +1,29 @@
+syntax = "proto3";
+
+package robot;
+
+option go_package = "./pkg/grpc";
+
+service LlamaGoService {
+    rpc Answer(Job) returns (stream Output){}
+    rpc Embed(Job) returns (Output){}
+}
+
+message Job {
+    string id = 1;
+    string prompt = 2;
+}
+
+enum Status {
+    PENDING = 0;
+    RUNNING = 1;
+    FINISHED = 2;
+    FAILED = 3;
+}
+
+message Output {
+    string id = 1;
+    Status status = 2;
+    string output = 3;
+    repeated float embed = 4;
+}
\ No newline at end of file
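
Usage note (not part of the patch above): with the patch applied, the new target is enabled by configuring CMake with -DLLAMA_BUILD_GRPC_SERVER=ON, and any gRPC language can talk to the service defined in message.proto. Below is a minimal sketch of a synchronous C++ client that calls the unary Embed RPC using the generated message.grpc.pb.h stub; the address localhost:8080 (the server's default port), the job id "job-1", and the fallback prompt are illustrative assumptions, not values taken from the patch.

#include <iostream>
#include <memory>
#include <string>

#include <grpcpp/grpcpp.h>

#include "message.grpc.pb.h"

int main(int argc, char **argv)
{
    // Plaintext channel, matching the InsecureServerCredentials used by grpc-server.
    auto channel = grpc::CreateChannel("localhost:8080", grpc::InsecureChannelCredentials());
    std::unique_ptr<robot::LlamaGoService::Stub> stub = robot::LlamaGoService::NewStub(channel);

    // Build the request message defined in message.proto.
    robot::Job job;
    job.set_id("job-1");
    job.set_prompt(argc > 1 ? argv[1] : "Hello llama");

    // Blocking (synchronous) call to the unary Embed RPC.
    robot::Output output;
    grpc::ClientContext context;
    grpc::Status status = stub->Embed(&context, job, &output);
    if (!status.ok())
    {
        std::cerr << "Embed failed: " << status.error_message() << std::endl;
        return 1;
    }

    std::cout << "id=" << output.id() << " embedding size=" << output.embed_size() << std::endl;
    return 0;
}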