1) Make gpt_params_parse able to skip over a caller-supplied list of otherwise-unknown args, so the gpt_params_parse function can be reused by the server.

2) Fix the gRPC server error.
Liu Ming 2023-05-29 14:07:13 +08:00
parent 530eb57fe4
commit 0773028d52
3 changed files with 569 additions and 289 deletions

File diff suppressed because it is too large
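The new parser entry point declared in the header hunk below is implemented in the suppressed diff above. As a rough illustration of point 1 of the commit message, a parser that tolerates a caller-supplied list of "extra" arguments could look like the following minimal sketch; apart from the function name and signature, everything here (the stand-in gpt_params fields, the handled flags, and the assumption that each extra flag takes exactly one value) is an assumption made for illustration, not the code from this commit.

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

struct gpt_params {              // stand-in with two fields; the real struct is in the header diff below
    std::string model = "ggml-model.bin";
    int         n_ctx = 512;
};

bool gpt_params_parse_with_extra_check(int argc, char **argv, gpt_params &params,
                                       std::vector<std::string> *extra_args)
{
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];

        // Caller-declared flag: skip it (and, by assumption, its single value)
        // instead of failing with "unknown argument".
        if (extra_args && std::find(extra_args->begin(), extra_args->end(), arg) != extra_args->end()) {
            if (i + 1 < argc) i++;
            continue;
        }

        if (arg == "-m" || arg == "--model") {
            if (++i >= argc) return false;
            params.model = argv[i];
        } else if (arg == "-c" || arg == "--ctx_size") {
            if (++i >= argc) return false;
            params.n_ctx = std::stoi(argv[i]);
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            return false;
        }
    }
    return true;
}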


@@ -74,6 +74,7 @@ struct gpt_params {
};
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
bool gpt_params_parse_with_extra_check(int argc, char **argv, gpt_params &params, std::vector<std::string> *extra_args);
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);


@@ -32,9 +32,6 @@
#include "message.grpc.pb.h"
#endif
// ABSL_FLAG(uint16_t, port, 50051, "Server port for the service");
// ABSL_FLAG(std::string, target, "localhost:50051", "Server address");
using grpc::CallbackServerContext;
using grpc::Server;
using grpc::ServerAsyncWriter;
@@ -149,6 +146,8 @@ public:
params.antiprompt.clear();
no_show_words.clear();
num_tokens_predicted = 0;
// embd.clear();
// n_remain = 0;
// generated_text = "";
}
@@ -250,6 +249,7 @@ public:
{
n_remain -= new_prompt_len;
}
fprintf(stderr, "embd_inp size %d,%d\n", embd_inp.size(), params.n_ctx);
if ((int)embd_inp.size() > params.n_ctx - 4)
{
return false;
@@ -274,6 +274,7 @@ public:
llama_token nextToken()
{
llama_token result = -1;
// fprintf(stderr, "embed size %d,%d,%d\n", embd.size(), embd_inp.size(), n_consumed);
if (embd.size() > 0)
{
if (n_past + (int)embd.size() > params.n_ctx)
@@ -482,6 +483,12 @@ public:
}
}
void printInfo()
{
fprintf(stderr, "embed size: %d\n", embd.size());
fprintf(stderr, "embd_inp size: %d\n", embd_inp.size());
}
private:
gpt_params params;
llama_context *ctx;
@@ -523,6 +530,11 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
// fprintf(stderr, "on write done");
NextWrite();
}
void OnCancel() override
{
fprintf(stderr, "on cancel");
delete this;
}
private:
CallbackServerContext *const ctx_;
@@ -541,17 +553,21 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
std::lock_guard<std::mutex> l(finish_mu_);
auto result = llama_->doCompletion();
fprintf(stderr, "%s", result.c_str());
response.set_status(llama::Status::RUNNING);
response.set_output(result);
StartWrite(&response);
}
else
{
{
response.set_status(llama::Status::FINISHED);
std::lock_guard<std::mutex>
l(finish_mu_);
if (!finished_)
{
response.set_status(llama::Status::FINISHED);
StartWriteLast(&response, grpc::WriteOptions());
}
}
// If we use WriteLast, we shouldn't wait before attempting Finish
FinishOnce(Status::OK);
}
@@ -571,13 +587,12 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
public:
LlamaServiceImpl(LlamaServerContext *llama_) : llama(llama_)
{
fprintf(stderr, "%s : new impl\n", __func__);
}
ServerWriteReactor<Output> *Answer(
CallbackServerContext *context, const Job *request)
{
fprintf(stderr, "%s : get answer\n", __func__);
fprintf(stderr, "%s : new answer request: %s\n", __func__, request->prompt().c_str());
llama->rewind();
// std::vector<float> embeded = llama->complete(request->prompt());
Reactor *reactor = new Reactor(context, llama, request);
@@ -625,6 +640,12 @@ void RunServer(uint16_t port, LlamaServerContext *llama)
server->Wait();
}
// auto server_params_parse(server_params &sparams)
// {
// return set_extra_params;
// }
bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
{
gpt_params default_params;
@@ -652,64 +673,6 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
}
sparams.hostname = argv[i];
}
else if (arg == "-s" || arg == "--seed")
{
#if defined(GGML_USE_CUBLAS)
fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
#endif
if (++i >= argc)
{
invalid_param = true;
break;
}
params.seed = std::stoi(argv[i]);
}
else if (arg == "-m" || arg == "--model")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.model = argv[i];
}
else if (arg == "--embedding")
{
params.embedding = true;
}
else if (arg == "-h" || arg == "--help")
{
server_print_usage(argc, argv, default_params);
exit(0);
}
else if (arg == "-c" || arg == "--ctx_size")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.n_ctx = std::stoi(argv[i]);
}
else if (arg == "--memory_f32")
{
params.memory_f16 = false;
}
else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.n_gpu_layers = std::stoi(argv[i]);
}
else
{
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
server_print_usage(argc, argv, default_params);
exit(1);
}
}
if (invalid_param)
@@ -731,10 +694,18 @@ int main(int argc, char **argv)
params.model = "ggml-model.bin";
params.n_ctx = 512;
// params.embedding = true;
sparams.port = 8080;
if (gpt_params_parse(argc, argv, params) == false)
std::vector<std::string> extra_args = {"--server", "--port"};
if (server_params_parse(argc, argv, sparams, params) == false)
{
return 1;
}
if (gpt_params_parse_with_extra_check(argc, argv, params, &extra_args) == false)
{
return 1;
}
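After these changes, main() runs server_params_parse first for the server-specific options and then hands the same argv to gpt_params_parse_with_extra_check, listing "--server" and "--port" as extra args so the common parser will not reject them. A hypothetical invocation (the binary name and the exact meaning of --server are assumptions; judging from the hostname handling above, --server most likely sets the listen address, and the port defaults to 8080):

./grpc-server -m ggml-model.bin -c 2048 --server 0.0.0.0 --port 50051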