Implement TCP server mode.
This new mode works by first loading the model then listening for TCP connections on a port. When a connection is received, arguments will be parsed using a simple protocol: - First the number of arguments will be read followed by a newline character. - Then each argument will be read, separated by the 0 byte. - With this we build an argument vector, similar to what is passed to the program entry point. We pass this to gpt_params_parse. Finally `llama_main` will be executed with the input/output streams connected to the socket. Signed-off-by: Thiago Padilha <thiago@padilha.cc>
This commit is contained in:
parent
9ed33b37de
commit
b6fdbee3de
9 changed files with 335 additions and 2 deletions
|
@ -112,6 +112,10 @@ add_executable(llama
|
||||||
llama.cpp
|
llama.cpp
|
||||||
utils.h)
|
utils.h)
|
||||||
|
|
||||||
|
if(NOT WIN32)
|
||||||
|
target_sources(llama PRIVATE tcp_server.cpp)
|
||||||
|
endif()
|
||||||
|
|
||||||
add_executable(quantize
|
add_executable(quantize
|
||||||
quantize.cpp
|
quantize.cpp
|
||||||
utils.cpp
|
utils.cpp
|
||||||
|
|
7
Makefile
7
Makefile
|
@ -191,11 +191,14 @@ utils.o: utils.cpp utils.h
|
||||||
llama.o: llama.cpp llama.h
|
llama.o: llama.cpp llama.h
|
||||||
$(CXX) $(CXXFLAGS) -c llama.cpp -o llama.o
|
$(CXX) $(CXXFLAGS) -c llama.cpp -o llama.o
|
||||||
|
|
||||||
|
tcp_server.o: tcp_server.cpp tcp_server.h
|
||||||
|
$(CXX) $(CXXFLAGS) -c tcp_server.cpp -o tcp_server.o
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm -f *.o main quantize
|
rm -f *.o main quantize
|
||||||
|
|
||||||
main: main.cpp ggml.o utils.o llama.o
|
main: main.cpp ggml.o utils.o llama.o tcp_server.o
|
||||||
$(CXX) $(CXXFLAGS) main.cpp ggml.o utils.o llama.o -o main $(LDFLAGS)
|
$(CXX) $(CXXFLAGS) main.cpp ggml.o utils.o llama.o tcp_server.o -o main $(LDFLAGS)
|
||||||
./main -h
|
./main -h
|
||||||
|
|
||||||
quantize: quantize.cpp ggml.o utils.o
|
quantize: quantize.cpp ggml.o utils.o
|
||||||
|
|
45
chat_tcp_client.sh
Executable file
45
chat_tcp_client.sh
Executable file
|
@ -0,0 +1,45 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
PORT=${PORT:-8080}
|
||||||
|
PROMPT="${PROMPT:-"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
|
||||||
|
|
||||||
|
User:Hello, Bob.
|
||||||
|
Bob:Hello. How may I help you today?
|
||||||
|
User:Please tell me the largest city in Europe.
|
||||||
|
Bob:Sure. The largest city in Europe is Moscow, the capital of Russia.
|
||||||
|
User:"}"
|
||||||
|
RPROMPT="${RPROMPT:-"User:"}"
|
||||||
|
N_PREDICT="${N_PREDICT:-"4096"}"
|
||||||
|
REPEAT_PENALTY="${REPEAT_PENALTY:-"1.0"}"
|
||||||
|
N_THREADS="${N_THREADS:-"4"}"
|
||||||
|
|
||||||
|
# Open connection to the chat server
|
||||||
|
exec 3<>/dev/tcp/127.0.0.1/${PORT}
|
||||||
|
|
||||||
|
# Pass the arguments. The protocol is really simple:
|
||||||
|
# 1. Pass the number of arguments followed by a linefeed
|
||||||
|
# 2. Pass the arguments, with each being followed by "0"
|
||||||
|
(
|
||||||
|
echo -en "12\n"
|
||||||
|
echo -en "-t\x00"
|
||||||
|
echo -en "$N_THREADS\x00"
|
||||||
|
echo -en "-n\x00"
|
||||||
|
echo -en "$N_PREDICT\x00"
|
||||||
|
echo -en "--repeat_penalty\x00"
|
||||||
|
echo -en "$REPEAT_PENALTY\x00"
|
||||||
|
echo -en "--color\x00"
|
||||||
|
echo -en "-i\x00"
|
||||||
|
echo -en "-r\x00"
|
||||||
|
echo -en "$RPROMPT\x00"
|
||||||
|
echo -en "-p\x00"
|
||||||
|
echo -en "$PROMPT\x00"
|
||||||
|
) >&3
|
||||||
|
|
||||||
|
trap exit TERM
|
||||||
|
|
||||||
|
# When we have passed the arguments, start printing socket data to the screen.
|
||||||
|
# This is done in a background job because we also want to send data when
|
||||||
|
# running in interactive mode.
|
||||||
|
cat <&3 && echo "(disconnected, press \"enter\" twice to exit)" &
|
||||||
|
cat >&3
|
||||||
|
wait
|
6
chat_tcp_server.sh
Executable file
6
chat_tcp_server.sh
Executable file
|
@ -0,0 +1,6 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
PORT=${PORT:-8080}
|
||||||
|
MODEL=${MODEL:-models/7B/ggml-model-q4_0.bin}
|
||||||
|
|
||||||
|
./main -l ${PORT} -m $MODEL
|
7
main.cpp
7
main.cpp
|
@ -1,6 +1,7 @@
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
#include "tcp_server.h"
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
|
@ -65,5 +66,11 @@ int main(int argc, char ** argv) {
|
||||||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifndef _WIN32
|
||||||
|
if (params.listen_port != "") {
|
||||||
|
return listen_tcp(params, vocab, model, t_main_start_us, t_load_us);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
return llama_main(params, vocab, model, t_main_start_us, t_load_us, std::cin, stdout, stderr);
|
return llama_main(params, vocab, model, t_main_start_us, t_load_us, std::cin, stdout, stderr);
|
||||||
}
|
}
|
||||||
|
|
245
tcp_server.cpp
Normal file
245
tcp_server.cpp
Normal file
|
@ -0,0 +1,245 @@
|
||||||
|
#include "tcp_server.h"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include <stdarg.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <errno.h>
|
||||||
|
|
||||||
|
#include <signal.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <sys/wait.h>
|
||||||
|
|
||||||
|
#include <sys/types.h>
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <netdb.h>
|
||||||
|
|
||||||
|
class PosixStream : public std::istream {
|
||||||
|
public:
|
||||||
|
PosixStream(int fd) : std::istream(&buf), buf(fd) {}
|
||||||
|
~PosixStream() { close(buf.get_fd()); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
class PosixStreamBuf : public std::streambuf {
|
||||||
|
public:
|
||||||
|
PosixStreamBuf(int fd) : fd(fd) {}
|
||||||
|
int get_fd() const { return fd; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual int_type underflow() {
|
||||||
|
if (gptr() < egptr()) {
|
||||||
|
return traits_type::to_int_type(*gptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
ssize_t num_read = ::read(fd, buffer, BUFFER_SIZE);
|
||||||
|
if (num_read <= 0) {
|
||||||
|
return traits_type::eof();
|
||||||
|
}
|
||||||
|
|
||||||
|
setg(buffer, buffer, buffer + num_read);
|
||||||
|
return traits_type::to_int_type(*gptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static const int BUFFER_SIZE = 1024;
|
||||||
|
int fd;
|
||||||
|
char buffer[BUFFER_SIZE];
|
||||||
|
};
|
||||||
|
|
||||||
|
PosixStreamBuf buf;
|
||||||
|
};
|
||||||
|
|
||||||
|
void die(const char *msg, ...)
|
||||||
|
{
|
||||||
|
va_list ap;
|
||||||
|
|
||||||
|
va_start(ap, msg);
|
||||||
|
vfprintf(stderr, msg, ap);
|
||||||
|
va_end(ap);
|
||||||
|
fputc('\n', stderr);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
static char *read_argument(uint8_t **param_buf, size_t *param_buf_size, FILE *instream) {
|
||||||
|
bool done = false;
|
||||||
|
uint8_t *buf = *param_buf;
|
||||||
|
size_t bufsize = *param_buf_size;
|
||||||
|
size_t bufpos = 0;
|
||||||
|
while (!done) {
|
||||||
|
if (bufpos == bufsize) {
|
||||||
|
bufsize += 1024;
|
||||||
|
buf = (uint8_t *)realloc(buf, bufsize);
|
||||||
|
if (!buf) {
|
||||||
|
die("failed to allocate memory");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int c = fgetc(instream);
|
||||||
|
if (c == EOF) {
|
||||||
|
die("unexpected EOF client socket");
|
||||||
|
}
|
||||||
|
buf[bufpos++] = (uint8_t)c;
|
||||||
|
if (c == 0) {
|
||||||
|
// done reading argument
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*param_buf = buf;
|
||||||
|
*param_buf_size = bufsize;
|
||||||
|
return strdup((char *)buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int read_arguments(int argc, char **argv, FILE *instream) {
|
||||||
|
int i = 1;
|
||||||
|
size_t param_buf_size = 0;
|
||||||
|
uint8_t *param_buf = nullptr;
|
||||||
|
|
||||||
|
for (i = 1; i < argc; i++) {
|
||||||
|
argv[i] = read_argument(¶m_buf, ¶m_buf_size, instream);
|
||||||
|
}
|
||||||
|
|
||||||
|
free(param_buf);
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int serve_model(
|
||||||
|
gpt_params params,
|
||||||
|
gpt_vocab vocab,
|
||||||
|
llama_model model,
|
||||||
|
int64_t t_load_us,
|
||||||
|
int64_t t_main_start_us,
|
||||||
|
int sock_fd)
|
||||||
|
{
|
||||||
|
char *response_data;
|
||||||
|
int argc;
|
||||||
|
char **argv;
|
||||||
|
FILE *instream = fdopen(sock_fd, "r");
|
||||||
|
FILE *outstream = fdopen(sock_fd, "w");
|
||||||
|
setvbuf(instream, NULL, _IONBF, 0);
|
||||||
|
|
||||||
|
// start by reading the parameter count
|
||||||
|
if (fscanf(instream, "%d\n", &argc) != 1) {
|
||||||
|
fprintf(outstream, "Error: First line must be character count\n");
|
||||||
|
fflush(outstream);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
argc += 1; // add one extra argument to emulate the program command line
|
||||||
|
argv = (char **)malloc(argc * sizeof *argv);
|
||||||
|
argv[0] = nullptr;
|
||||||
|
if (read_arguments(argc, argv, instream) != argc) {
|
||||||
|
fprintf(outstream, "Error: Failed to read arguments\n");
|
||||||
|
fflush(outstream);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (gpt_params_parse(argc, argv, params) == false) {
|
||||||
|
fprintf(outstream, "Error: Failed to parse parameters\n");
|
||||||
|
fflush(outstream);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 1; i < argc; i++) {
|
||||||
|
free(argv[i]);
|
||||||
|
}
|
||||||
|
free(argv);
|
||||||
|
|
||||||
|
PosixStream tcp_is(sock_fd);
|
||||||
|
|
||||||
|
return llama_main(params, vocab, model, t_load_us, t_main_start_us, tcp_is, outstream, outstream);
|
||||||
|
}
|
||||||
|
|
||||||
|
int listen_tcp(
|
||||||
|
gpt_params params,
|
||||||
|
gpt_vocab vocab,
|
||||||
|
llama_model model,
|
||||||
|
int64_t t_main_start_us,
|
||||||
|
int64_t t_load_us) {
|
||||||
|
int listen_fd;
|
||||||
|
int status;
|
||||||
|
pid_t child;
|
||||||
|
struct addrinfo hints;
|
||||||
|
struct addrinfo *servinfo, *p;
|
||||||
|
int yes = 1;
|
||||||
|
|
||||||
|
memset(&hints, 0, sizeof hints);
|
||||||
|
hints.ai_family = AF_INET;
|
||||||
|
hints.ai_socktype = SOCK_STREAM;
|
||||||
|
hints.ai_flags = AI_PASSIVE;
|
||||||
|
|
||||||
|
// This should only ever listen on a loopback address. Access from outside
|
||||||
|
// should be proxied via nginx or similar software
|
||||||
|
status = getaddrinfo("127.0.0.1", params.listen_port.c_str(), &hints, &servinfo);
|
||||||
|
if (status) {
|
||||||
|
die("getaddrinfo error: %s", gai_strerror(status));
|
||||||
|
}
|
||||||
|
|
||||||
|
// bind to the first addrinfo we can from the getaddrinfo results
|
||||||
|
for (p = servinfo; p != NULL; p = p->ai_next) {
|
||||||
|
listen_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
|
||||||
|
if (listen_fd == -1) {
|
||||||
|
perror("server: socket");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes)) {
|
||||||
|
die("setsockopt error: %s", params.listen_port.c_str(), strerror(errno));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bind(listen_fd, p->ai_addr, p->ai_addrlen) == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
close(listen_fd);
|
||||||
|
perror("server: bind");
|
||||||
|
}
|
||||||
|
|
||||||
|
freeaddrinfo(servinfo);
|
||||||
|
|
||||||
|
if (p == NULL) {
|
||||||
|
die("failed to bind: %s", strerror(errno));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (listen(listen_fd, 20)) {
|
||||||
|
die("listen error: %s", strerror(errno));
|
||||||
|
}
|
||||||
|
// Don't track child processes, so ignore SIGCHLD to prevent zombies
|
||||||
|
signal(SIGCHLD, SIG_IGN);
|
||||||
|
|
||||||
|
for (;;) {
|
||||||
|
struct sockaddr_in client_addr = {0};
|
||||||
|
socklen_t client_addr_len = 0;
|
||||||
|
|
||||||
|
int sock_fd = accept(listen_fd,
|
||||||
|
(struct sockaddr *)&client_addr,
|
||||||
|
&client_addr_len);
|
||||||
|
if (sock_fd < 0) {
|
||||||
|
fprintf(stderr, "accept error: %s\n", strerror(errno));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
child = fork();
|
||||||
|
if (child == 0) {
|
||||||
|
// close the listen_fd since we won't use it in the child
|
||||||
|
close(listen_fd);
|
||||||
|
int ret = serve_model(params, vocab, model, t_main_start_us, t_load_us, sock_fd);
|
||||||
|
close(sock_fd);
|
||||||
|
return ret;
|
||||||
|
} else {
|
||||||
|
// close the client since we won't use it in the server
|
||||||
|
close(sock_fd);
|
||||||
|
sock_fd = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(listen_fd);
|
||||||
|
|
||||||
|
// ignore SIGTERM since we'll send it to the group
|
||||||
|
signal(SIGTERM, SIG_IGN);
|
||||||
|
// tell children to exit
|
||||||
|
kill(0, SIGTERM);
|
||||||
|
// wait for children to terminate
|
||||||
|
wait(&status);
|
||||||
|
return 0;
|
||||||
|
}
|
11
tcp_server.h
Normal file
11
tcp_server.h
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "utils.h"
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
int listen_tcp(
|
||||||
|
gpt_params params,
|
||||||
|
gpt_vocab vocab,
|
||||||
|
llama_model model,
|
||||||
|
int64_t t_main_start_us,
|
||||||
|
int64_t t_load_us);
|
|
@ -74,6 +74,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||||
params.antiprompt.push_back(argv[++i]);
|
params.antiprompt.push_back(argv[++i]);
|
||||||
} else if (arg == "--ignore-eos") {
|
} else if (arg == "--ignore-eos") {
|
||||||
params.ignore_eos = true;
|
params.ignore_eos = true;
|
||||||
|
#ifndef _WIN32
|
||||||
|
} else if (arg == "-l" || arg == "--listen") {
|
||||||
|
params.listen_port = argv[++i];
|
||||||
|
#endif
|
||||||
} else if (arg == "-h" || arg == "--help") {
|
} else if (arg == "-h" || arg == "--help") {
|
||||||
gpt_print_usage(argc, argv, params);
|
gpt_print_usage(argc, argv, params);
|
||||||
exit(0);
|
exit(0);
|
||||||
|
@ -119,6 +123,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||||
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||||
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||||
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
||||||
|
#ifndef _WIN32
|
||||||
|
fprintf(stderr, " -l PORT, --listen PORT\n");
|
||||||
|
fprintf(stderr, " Run in TCP mode, listening on PORT\n");
|
||||||
|
#endif
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
4
utils.h
4
utils.h
|
@ -40,6 +40,10 @@ struct gpt_params {
|
||||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||||
bool instruct = false; // instruction mode (used for Alpaca models)
|
bool instruct = false; // instruction mode (used for Alpaca models)
|
||||||
bool ignore_eos = false; // do not stop generating after eos
|
bool ignore_eos = false; // do not stop generating after eos
|
||||||
|
|
||||||
|
#ifndef _WIN32
|
||||||
|
std::string listen_port = ""; // TCP port for when running in server mode
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue