rpc : enable async operations

Start a dedicated backend thread in the rpc-server and use a message-passing
interface for submitting work to it. This enables async backend operations
and cross-server communication.
commit b30565e0c8
parent 172c825684
Author: Radoslav Gerganov
Date:   2024-06-13 09:57:24 +03:00
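
The core of the change is a small blocking message queue shared between the
socket-serving loop and a single dedicated backend thread. Below is a minimal,
self-contained sketch of that producer/consumer pattern, with the payload
simplified to a string; the commit's real rpc_request/rpc_response types and
the full queue appear in the diff that follows.

    #include <condition_variable>
    #include <cstdio>
    #include <memory>
    #include <mutex>
    #include <queue>
    #include <string>
    #include <thread>

    // Blocking FIFO: any thread may push; the consumer blocks in pop() until
    // an item is available. Same shape as the message_queue in the diff.
    template <typename T>
    class message_queue {
        std::queue<T> queue;
        std::mutex mutex;
        std::condition_variable cvar;
    public:
        void push(const T & value) {
            std::unique_lock<std::mutex> lock(mutex);
            queue.push(value);
            lock.unlock();
            cvar.notify_all();
        }
        void pop(T * out) {
            std::unique_lock<std::mutex> lock(mutex);
            cvar.wait(lock, [this] { return !queue.empty(); });
            *out = queue.front();
            queue.pop();
        }
    };

    int main() {
        auto requestq = std::make_shared<message_queue<std::string>>();
        // dedicated "backend" thread: the sole consumer of the request queue
        std::thread backend([=] {
            std::string request;
            requestq->pop(&request);   // blocks until a request is submitted
            std::printf("processing: %s\n", request.c_str());
        });
        requestq->push("GRAPH_COMPUTE");   // submit work from the serving thread
        backend.join();
        return 0;
    }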

@@ -5,8 +5,11 @@
 #include <cinttypes>
 #include <string>
 #include <vector>
+#include <queue>
 #include <memory>
 #include <mutex>
+#include <thread>
+#include <condition_variable>
 #include <unordered_map>
 #include <unordered_set>
 #ifdef _WIN32
@@ -17,6 +20,7 @@
 #  include <windows.h>
 #  include <winsock2.h>
 #else
+#  include <signal.h>
 #  include <arpa/inet.h>
 #  include <sys/socket.h>
 #  include <sys/types.h>
@@ -89,6 +93,7 @@ enum rpc_cmd {
     COPY_TENSOR,
     GRAPH_COMPUTE,
     GET_DEVICE_MEMORY,
+    FREE_ALL_BUFFERS,
 };

 // RPC data structures
@@ -736,6 +741,48 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint
 // RPC server-side implementation

+template <typename T>
+class message_queue {
+    std::queue<T> queue;
+    std::mutex mutex;
+    std::condition_variable cvar;
+
+public:
+    message_queue() {}
+
+    void push(const T &value) {
+        std::unique_lock<std::mutex> lock(mutex);
+        queue.push(value);
+        lock.unlock();
+        cvar.notify_all();
+    }
+
+    void pop(T* out) {
+        std::unique_lock<std::mutex> lock(mutex);
+        cvar.wait(lock, [this] { return queue.size() > 0; });
+        *out = queue.front();
+        queue.pop();
+    }
+};
+
+struct rpc_response {
+    std::vector<uint8_t> output;
+    bool status;
+};
+using rpc_response_ptr = std::shared_ptr<rpc_response>;
+using response_queue = message_queue<rpc_response_ptr>;
+using response_queue_ptr = std::shared_ptr<response_queue>;
+
+struct rpc_request {
+    rpc_cmd cmd;
+    std::vector<uint8_t> input;
+    response_queue_ptr response_queue;
+};
+using rpc_request_ptr = std::shared_ptr<rpc_request>;
+using request_queue = message_queue<rpc_request_ptr>;
+using request_queue_ptr = std::shared_ptr<request_queue>;
+
 class rpc_server {
 public:
     rpc_server(ggml_backend_t backend) : backend(backend) {}
@@ -752,6 +799,7 @@ public:
     bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
     bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    void free_all_buffers();

 private:
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
     ggml_tensor * create_node(uint64_t id,
@@ -1046,76 +1094,122 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output)
     return true;
 }

-rpc_server::~rpc_server() {
+void rpc_server::free_all_buffers() {
     for (auto buffer : buffers) {
         ggml_backend_buffer_free(buffer);
     }
+    buffers.clear();
 }

-static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+rpc_server::~rpc_server() {
+    free_all_buffers();
+}
+
+static void process_requests(ggml_backend_t backend, request_queue_ptr requestq) {
     rpc_server server(backend);
+    while (true) {
+        rpc_request_ptr request;
+        requestq->pop(&request);
+        rpc_response_ptr response = std::make_shared<rpc_response>();
+        bool ok = true;
+        switch (request->cmd) {
+            case ALLOC_BUFFER: {
+                ok = server.alloc_buffer(request->input, response->output);
+                break;
+            }
+            case GET_ALIGNMENT: {
+                server.get_alignment(response->output);
+                break;
+            }
+            case GET_MAX_SIZE: {
+                server.get_max_size(response->output);
+                break;
+            }
+            case BUFFER_GET_BASE: {
+                ok = server.buffer_get_base(request->input, response->output);
+                break;
+            }
+            case FREE_BUFFER: {
+                ok = server.free_buffer(request->input);
+                break;
+            }
+            case BUFFER_CLEAR: {
+                ok = server.buffer_clear(request->input);
+                break;
+            }
+            case SET_TENSOR: {
+                ok = server.set_tensor(request->input);
+                break;
+            }
+            case GET_TENSOR: {
+                ok = server.get_tensor(request->input, response->output);
+                break;
+            }
+            case COPY_TENSOR: {
+                ok = server.copy_tensor(request->input, response->output);
+                break;
+            }
+            case GRAPH_COMPUTE: {
+                ok = server.graph_compute(request->input, response->output);
+                break;
+            }
+            case GET_DEVICE_MEMORY: {
+                break;
+            }
+            case FREE_ALL_BUFFERS: {
+                server.free_all_buffers();
+                continue;
+            }
+            default: {
+                fprintf(stderr, "Unknown command: %d\n", request->cmd);
+                ok = false;
+            }
+        }
+        response->status = ok;
+        request->response_queue->push(response);
+    }
+}
+
+static void rpc_serve_client(request_queue_ptr requestq, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+    auto responseq = std::make_shared<response_queue>();
     while (true) {
         uint8_t cmd;
         if (!recv_data(sockfd, &cmd, 1)) {
             break;
         }
-        std::vector<uint8_t> input;
-        std::vector<uint8_t> output;
+        auto request = std::make_shared<rpc_request>();
+        request->cmd = (rpc_cmd)cmd;
+        request->response_queue = responseq;
         uint64_t input_size;
         if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
             break;
         }
-        input.resize(input_size);
-        if (!recv_data(sockfd, input.data(), input_size)) {
+        request->input.resize(input_size);
+        if (!recv_data(sockfd, request->input.data(), input_size)) {
             break;
         }
         bool ok = true;
+        auto response = std::make_shared<rpc_response>();
         switch (cmd) {
-            case ALLOC_BUFFER: {
-                ok = server.alloc_buffer(input, output);
-                break;
-            }
-            case GET_ALIGNMENT: {
-                server.get_alignment(output);
-                break;
-            }
-            case GET_MAX_SIZE: {
-                server.get_max_size(output);
-                break;
-            }
-            case BUFFER_GET_BASE: {
-                ok = server.buffer_get_base(input, output);
-                break;
-            }
-            case FREE_BUFFER: {
-                ok = server.free_buffer(input);
-                break;
-            }
-            case BUFFER_CLEAR: {
-                ok = server.buffer_clear(input);
-                break;
-            }
-            case SET_TENSOR: {
-                ok = server.set_tensor(input);
-                break;
-            }
-            case GET_TENSOR: {
-                ok = server.get_tensor(input, output);
-                break;
-            }
-            case COPY_TENSOR: {
-                ok = server.copy_tensor(input, output);
-                break;
-            }
+            case ALLOC_BUFFER:
+            case GET_ALIGNMENT:
+            case GET_MAX_SIZE:
+            case BUFFER_GET_BASE:
+            case FREE_BUFFER:
+            case BUFFER_CLEAR:
+            case SET_TENSOR:
+            case GET_TENSOR:
+            case COPY_TENSOR:
             case GRAPH_COMPUTE: {
-                ok = server.graph_compute(input, output);
+                requestq->push(request);
+                responseq->pop(&response);
                 break;
             }
             case GET_DEVICE_MEMORY: {
                 // output serialization format: | free (8 bytes) | total (8 bytes) |
-                output.resize(2*sizeof(uint64_t), 0);
-                memcpy(output.data(), &free_mem, sizeof(free_mem));
-                memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
+                response->output.resize(2*sizeof(uint64_t), 0);
+                memcpy(response->output.data(), &free_mem, sizeof(free_mem));
+                memcpy(response->output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
                 break;
             }
             default: {
@@ -1126,17 +1220,29 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem)
         if (!ok) {
             break;
         }
-        uint64_t output_size = output.size();
+        uint64_t output_size = response->output.size();
         if (!send_data(sockfd, &output_size, sizeof(output_size))) {
             break;
         }
-        if (!send_data(sockfd, output.data(), output_size)) {
+        if (!send_data(sockfd, response->output.data(), output_size)) {
             break;
         }
     }
+    auto request = std::make_shared<rpc_request>();
+    request->cmd = FREE_ALL_BUFFERS;
+    requestq->push(request);
 }

 void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+#ifndef _WIN32
+    // prevent SIGPIPE when writing to closed socket
+    signal(SIGPIPE, SIG_IGN);
+#endif
+    auto requestq = std::make_shared<request_queue>();
+    std::thread backend_thread = std::thread([=] {
+        process_requests(backend, requestq);
+    });
     std::string host;
     int port;
     if (!parse_endpoint(endpoint, host, port)) {
@@ -1164,7 +1270,7 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem)
             return;
         }
         printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
-        rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
+        rpc_serve_client(requestq, client_socket->fd, free_mem, total_mem);
         printf("Client connection closed\n");
     }
 #ifdef _WIN32
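
From the caller's point of view, starting the server is unchanged; the new
backend thread is created inside start_rpc_server. For context, a hedged
sketch of a host program (assumes ggml's public headers and the CPU backend;
the endpoint string and memory figures are made up for the example):

    // hypothetical host program; assumes ggml-backend.h / ggml-rpc.h and links ggml
    #include "ggml-backend.h"
    #include "ggml-rpc.h"

    int main() {
        ggml_backend_t backend = ggml_backend_cpu_init();
        size_t free_mem  = 8ull  * 1024 * 1024 * 1024;  // advertise 8 GiB free
        size_t total_mem = 16ull * 1024 * 1024 * 1024;  // advertise 16 GiB total
        // serves clients forever; all backend work is funneled through the
        // request queue to the thread spawned inside start_rpc_server
        start_rpc_server(backend, "0.0.0.0:50052", free_mem, total_mem);
        ggml_backend_free(backend);
        return 0;
    }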