1) let gpt_params_parse skip over some predefined unknown args so the gpt_params_parse function can be reused
2) fixed the gRPC server error
parent 530eb57fe4
commit 0773028d52
3 changed files with 569 additions and 289 deletions
File diff suppressed because it is too large
@@ -74,6 +74,7 @@ struct gpt_params {
};

bool gpt_params_parse(int argc, char ** argv, gpt_params & params);

bool gpt_params_parse_with_extra_check(int argc, char **argv, gpt_params &params, std::vector<std::string> *extra_args);

void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
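The implementation of this new overload lives in the file whose diff is suppressed above, so it is not shown here. A minimal sketch of the idea behind commit-message point 1 — parse the flags this parser knows, and silently skip any argument the caller listed in extra_args instead of erroring out — might look like the following. The simplified gpt_params fields, the subset of handled flags, and the question of whether a skipped flag's value token must also be skipped are assumptions, not the actual code:

#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

// Simplified stand-in for the real gpt_params struct (assumption: only two fields shown).
struct gpt_params {
    int         n_ctx = 512;
    std::string model = "ggml-model.bin";
};

// Sketch: parse the flags this parser knows about, and skip any argument that
// the caller listed in *extra_args instead of treating it as an error.
bool gpt_params_parse_with_extra_check(int argc, char **argv, gpt_params &params,
                                       std::vector<std::string> *extra_args) {
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];
        if (arg == "-c" || arg == "--ctx_size") {
            if (++i >= argc) return false;
            params.n_ctx = std::stoi(argv[i]);
        } else if (arg == "-m" || arg == "--model") {
            if (++i >= argc) return false;
            params.model = argv[i];
        } else if (extra_args != nullptr &&
                   std::find(extra_args->begin(), extra_args->end(), arg) != extra_args->end()) {
            // Known-unknown flag (e.g. "--port"): leave it for the server's own parser.
            // Whether its value token is skipped too depends on the real implementation.
            continue;
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            return false;
        }
    }
    return true;
}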
@@ -32,9 +32,6 @@
#include "message.grpc.pb.h"
#endif

// ABSL_FLAG(uint16_t, port, 50051, "Server port for the service");
// ABSL_FLAG(std::string, target, "localhost:50051", "Server address");

using grpc::CallbackServerContext;
using grpc::Server;
using grpc::ServerAsyncWriter;
@@ -149,6 +146,8 @@ public:
        params.antiprompt.clear();
        no_show_words.clear();
        num_tokens_predicted = 0;
        // embd.clear();
        // n_remain = 0;
        // generated_text = "";
    }
@@ -250,6 +249,7 @@ public:
        {
            n_remain -= new_prompt_len;
        }
        fprintf(stderr, "embd_inp size %d,%d\n", (int)embd_inp.size(), params.n_ctx);
        if ((int)embd_inp.size() > params.n_ctx - 4)
        {
            return false;
@@ -274,6 +274,7 @@ public:
    llama_token nextToken()
    {
        llama_token result = -1;
        // fprintf(stderr, "embed size %d,%d,%d\n", embd.size(), embd_inp.size(), n_consumed);
        if (embd.size() > 0)
        {
            if (n_past + (int)embd.size() > params.n_ctx)
@@ -482,6 +483,12 @@ public:
        }
    }

    void printInfo()
    {
        fprintf(stderr, "embed size: %d\n", (int)embd.size());
        fprintf(stderr, "embd_inp size: %d\n", (int)embd_inp.size());
    }

private:
    gpt_params params;
    llama_context *ctx;
@@ -523,6 +530,11 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
            // fprintf(stderr, "on write done");
            NextWrite();
        }
        void OnCancel() override
        {
            fprintf(stderr, "on cancel");
            delete this;
        }

    private:
        CallbackServerContext *const ctx_;
@@ -541,16 +553,20 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
                std::lock_guard<std::mutex> l(finish_mu_);
                auto result = llama_->doCompletion();
                fprintf(stderr, "%s", result.c_str());
                response.set_status(llama::Status::RUNNING);
                response.set_output(result);
                StartWrite(&response);
            }
            else
            {
                {
                    response.set_status(llama::Status::FINISHED);
                    std::lock_guard<std::mutex> l(finish_mu_);
                    StartWriteLast(&response, grpc::WriteOptions());
                    if (!finished_)
                    {
                        response.set_status(llama::Status::FINISHED);
                        StartWriteLast(&response, grpc::WriteOptions());
                    }
                }
                // If we use WriteLast, we shouldn't wait before attempting Finish
                FinishOnce(Status::OK);
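This hunk looks like the fix for commit-message point 2: the final write is now issued under finish_mu_ and only while finished_ is still false, so the reactor cannot call StartWriteLast or Finish twice on the same stream. A gRPC-free sketch of that guard, assuming finished_ is the flag that FinishOnce sets and that the real class calls the gRPC write/finish APIs where the comments indicate, could be:

#include <mutex>

// Minimal illustration of a "complete the stream at most once" guard
// (assumption: names mirror the Reactor members above; the actual
// StartWriteLast/Finish calls are omitted here).
struct FinishOnceGuard {
    std::mutex finish_mu_;
    bool       finished_ = false;

    // Returns true exactly once; the caller may then send the last write and finish.
    bool try_finish() {
        std::lock_guard<std::mutex> l(finish_mu_);
        if (finished_) {
            return false; // the stream was already completed elsewhere (e.g. OnCancel)
        }
        finished_ = true;
        return true;
    }
};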
@@ -571,13 +587,12 @@ class LlamaServiceImpl final : public LlamaGoService::CallbackService
public:
    LlamaServiceImpl(LlamaServerContext *llama_) : llama(llama_)
    {
        fprintf(stderr, "%s : new impl\n", __func__);
    }

    ServerWriteReactor<Output> *Answer(
        CallbackServerContext *context, const Job *request)
    {
        fprintf(stderr, "%s : get answer\n", __func__);
        fprintf(stderr, "%s : new answer request: %s\n", __func__, request->prompt().c_str());
        llama->rewind();
        // std::vector<float> embeded = llama->complete(request->prompt());
        Reactor *reactor = new Reactor(context, llama, request);
@@ -625,6 +640,12 @@ void RunServer(uint16_t port, LlamaServerContext *llama)
    server->Wait();
}

// auto server_params_parse(server_params &sparams)
// {

//     return set_extra_params;
// }

bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
{
    gpt_params default_params;
@@ -652,64 +673,6 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params)
            }
            sparams.hostname = argv[i];
        }
        else if (arg == "-s" || arg == "--seed")
        {
#if defined(GGML_USE_CUBLAS)
            fprintf(stderr, "WARNING: when using cuBLAS generation results are NOT guaranteed to be reproducible.\n");
#endif
            if (++i >= argc)
            {
                invalid_param = true;
                break;
            }
            params.seed = std::stoi(argv[i]);
        }
        else if (arg == "-m" || arg == "--model")
        {
            if (++i >= argc)
            {
                invalid_param = true;
                break;
            }
            params.model = argv[i];
        }
        else if (arg == "--embedding")
        {
            params.embedding = true;
        }
        else if (arg == "-h" || arg == "--help")
        {
            server_print_usage(argc, argv, default_params);
            exit(0);
        }
        else if (arg == "-c" || arg == "--ctx_size")
        {
            if (++i >= argc)
            {
                invalid_param = true;
                break;
            }
            params.n_ctx = std::stoi(argv[i]);
        }
        else if (arg == "--memory_f32")
        {
            params.memory_f16 = false;
        }
        else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers")
        {
            if (++i >= argc)
            {
                invalid_param = true;
                break;
            }
            params.n_gpu_layers = std::stoi(argv[i]);
        }
        else
        {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            server_print_usage(argc, argv, default_params);
            exit(1);
        }
    }

    if (invalid_param)
@@ -731,10 +694,18 @@ int main(int argc, char **argv)

    params.model = "ggml-model.bin";
    params.n_ctx = 512;
    // params.embedding = true;

    sparams.port = 8080;

    if (gpt_params_parse(argc, argv, params) == false)
    std::vector<std::string> extra_args = {"--server", "--port"};

    if (server_params_parse(argc, argv, sparams, params) == false)
    {
        return 1;
    }

    if (gpt_params_parse_with_extra_check(argc, argv, params, &extra_args) == false)
    {
        return 1;
    }
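For context, after this change server_params_parse only needs to handle the server-side options: the duplicated llama flags and the unknown-argument error were removed above, so anything it does not recognize is simply left for gpt_params_parse_with_extra_check to validate. A rough sketch under those assumptions — the flag names --host and --port are inferred from sparams.hostname and the extra_args list, and the real server_params struct has more fields — could look like:

#include <cstdint>
#include <string>

// Simplified stand-in for the real server_params struct (assumption).
struct server_params {
    std::string hostname = "0.0.0.0";
    int32_t     port     = 8080;
};

// Sketch of the slimmed-down parser after this commit: it consumes only the
// server-specific flags and ignores everything else; the leftover arguments
// are checked later by gpt_params_parse_with_extra_check.
static bool server_params_parse_sketch(int argc, char **argv, server_params &sparams) {
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];
        if (arg == "--host") {
            if (++i >= argc) return false;
            sparams.hostname = argv[i];
        } else if (arg == "--port") {
            if (++i >= argc) return false;
            sparams.port = std::stoi(argv[i]);
        }
        // Anything else (e.g. -m, -c, -ngl) is left for the common parser.
    }
    return true;
}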