Merge b30565e0c8
into 8854044561
This commit is contained in:
commit
70ef817068
1 changed files with 155 additions and 49 deletions
204
ggml-rpc.cpp
204
ggml-rpc.cpp
|
@ -5,8 +5,11 @@
|
|||
#include <cinttypes>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <condition_variable>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#ifdef _WIN32
|
||||
|
@ -17,6 +20,7 @@
|
|||
# include <windows.h>
|
||||
# include <winsock2.h>
|
||||
#else
|
||||
# include <signal.h>
|
||||
# include <arpa/inet.h>
|
||||
# include <sys/socket.h>
|
||||
# include <sys/types.h>
|
||||
|
@ -93,6 +97,7 @@ enum rpc_cmd {
|
|||
COPY_TENSOR,
|
||||
GRAPH_COMPUTE,
|
||||
GET_DEVICE_MEMORY,
|
||||
FREE_ALL_BUFFERS,
|
||||
};
|
||||
|
||||
// RPC data structures
|
||||
|
@ -739,6 +744,48 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint
|
|||
|
||||
// RPC server-side implementation
|
||||
|
||||
template <typename T>
|
||||
class message_queue {
|
||||
std::queue<T> queue;
|
||||
std::mutex mutex;
|
||||
std::condition_variable cvar;
|
||||
|
||||
public:
|
||||
message_queue() {}
|
||||
|
||||
void push(const T &value) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
queue.push(value);
|
||||
lock.unlock();
|
||||
cvar.notify_all();
|
||||
}
|
||||
|
||||
void pop(T* out) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
cvar.wait(lock, [this] { return queue.size() > 0; });
|
||||
*out = queue.front();
|
||||
queue.pop();
|
||||
}
|
||||
};
|
||||
|
||||
struct rpc_response {
|
||||
std::vector<uint8_t> output;
|
||||
bool status;
|
||||
};
|
||||
|
||||
using rpc_response_ptr = std::shared_ptr<rpc_response>;
|
||||
using response_queue = message_queue<rpc_response_ptr>;
|
||||
using response_queue_ptr = std::shared_ptr<response_queue>;
|
||||
|
||||
struct rpc_request {
|
||||
rpc_cmd cmd;
|
||||
std::vector<uint8_t> input;
|
||||
response_queue_ptr response_queue;
|
||||
};
|
||||
using rpc_request_ptr = std::shared_ptr<rpc_request>;
|
||||
using request_queue = message_queue<rpc_request_ptr>;
|
||||
using request_queue_ptr = std::shared_ptr<request_queue>;
|
||||
|
||||
class rpc_server {
|
||||
public:
|
||||
rpc_server(ggml_backend_t backend) : backend(backend) {}
|
||||
|
@ -755,6 +802,7 @@ public:
|
|||
bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
||||
bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
||||
|
||||
void free_all_buffers();
|
||||
private:
|
||||
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
|
||||
ggml_tensor * create_node(uint64_t id,
|
||||
|
@ -1051,76 +1099,122 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
|
|||
return true;
|
||||
}
|
||||
|
||||
rpc_server::~rpc_server() {
|
||||
void rpc_server::free_all_buffers() {
|
||||
for (auto buffer : buffers) {
|
||||
ggml_backend_buffer_free(buffer);
|
||||
}
|
||||
buffers.clear();
|
||||
}
|
||||
|
||||
static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
||||
rpc_server::~rpc_server() {
|
||||
free_all_buffers();
|
||||
}
|
||||
|
||||
static void process_requests(ggml_backend_t backend, request_queue_ptr requestq) {
|
||||
rpc_server server(backend);
|
||||
while (true) {
|
||||
rpc_request_ptr request;
|
||||
requestq->pop(&request);
|
||||
rpc_response_ptr response = std::make_shared<rpc_response>();
|
||||
bool ok = true;
|
||||
switch (request->cmd) {
|
||||
case ALLOC_BUFFER: {
|
||||
ok = server.alloc_buffer(request->input, response->output);
|
||||
break;
|
||||
}
|
||||
case GET_ALIGNMENT: {
|
||||
server.get_alignment(response->output);
|
||||
break;
|
||||
}
|
||||
case GET_MAX_SIZE: {
|
||||
server.get_max_size(response->output);
|
||||
break;
|
||||
}
|
||||
case BUFFER_GET_BASE: {
|
||||
ok = server.buffer_get_base(request->input, response->output);
|
||||
break;
|
||||
}
|
||||
case FREE_BUFFER: {
|
||||
ok = server.free_buffer(request->input);
|
||||
break;
|
||||
}
|
||||
case BUFFER_CLEAR: {
|
||||
ok = server.buffer_clear(request->input);
|
||||
break;
|
||||
}
|
||||
case SET_TENSOR: {
|
||||
ok = server.set_tensor(request->input);
|
||||
break;
|
||||
}
|
||||
case GET_TENSOR: {
|
||||
ok = server.get_tensor(request->input, response->output);
|
||||
break;
|
||||
}
|
||||
case COPY_TENSOR: {
|
||||
ok = server.copy_tensor(request->input, response->output);
|
||||
break;
|
||||
}
|
||||
case GRAPH_COMPUTE: {
|
||||
ok = server.graph_compute(request->input, response->output);
|
||||
break;
|
||||
}
|
||||
case GET_DEVICE_MEMORY: {
|
||||
break;
|
||||
}
|
||||
case FREE_ALL_BUFFERS: {
|
||||
server.free_all_buffers();
|
||||
continue;
|
||||
}
|
||||
default: {
|
||||
fprintf(stderr, "Unknown command: %d\n", request->cmd);
|
||||
ok = false;
|
||||
}
|
||||
}
|
||||
response->status = ok;
|
||||
request->response_queue->push(response);
|
||||
}
|
||||
}
|
||||
|
||||
static void rpc_serve_client(request_queue_ptr requestq, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
||||
auto responseq = std::make_shared<response_queue>();
|
||||
while (true) {
|
||||
uint8_t cmd;
|
||||
if (!recv_data(sockfd, &cmd, 1)) {
|
||||
break;
|
||||
}
|
||||
std::vector<uint8_t> input;
|
||||
std::vector<uint8_t> output;
|
||||
auto request = std::make_shared<rpc_request>();
|
||||
request->cmd = (rpc_cmd)cmd;
|
||||
request->response_queue = responseq;
|
||||
uint64_t input_size;
|
||||
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
|
||||
break;
|
||||
}
|
||||
input.resize(input_size);
|
||||
if (!recv_data(sockfd, input.data(), input_size)) {
|
||||
request->input.resize(input_size);
|
||||
if (!recv_data(sockfd, request->input.data(), input_size)) {
|
||||
break;
|
||||
}
|
||||
bool ok = true;
|
||||
auto response = std::make_shared<rpc_response>();
|
||||
switch (cmd) {
|
||||
case ALLOC_BUFFER: {
|
||||
ok = server.alloc_buffer(input, output);
|
||||
break;
|
||||
}
|
||||
case GET_ALIGNMENT: {
|
||||
server.get_alignment(output);
|
||||
break;
|
||||
}
|
||||
case GET_MAX_SIZE: {
|
||||
server.get_max_size(output);
|
||||
break;
|
||||
}
|
||||
case BUFFER_GET_BASE: {
|
||||
ok = server.buffer_get_base(input, output);
|
||||
break;
|
||||
}
|
||||
case FREE_BUFFER: {
|
||||
ok = server.free_buffer(input);
|
||||
break;
|
||||
}
|
||||
case BUFFER_CLEAR: {
|
||||
ok = server.buffer_clear(input);
|
||||
break;
|
||||
}
|
||||
case SET_TENSOR: {
|
||||
ok = server.set_tensor(input);
|
||||
break;
|
||||
}
|
||||
case GET_TENSOR: {
|
||||
ok = server.get_tensor(input, output);
|
||||
break;
|
||||
}
|
||||
case COPY_TENSOR: {
|
||||
ok = server.copy_tensor(input, output);
|
||||
break;
|
||||
}
|
||||
case ALLOC_BUFFER:
|
||||
case GET_ALIGNMENT:
|
||||
case GET_MAX_SIZE:
|
||||
case BUFFER_GET_BASE:
|
||||
case FREE_BUFFER:
|
||||
case BUFFER_CLEAR:
|
||||
case SET_TENSOR:
|
||||
case GET_TENSOR:
|
||||
case COPY_TENSOR:
|
||||
case GRAPH_COMPUTE: {
|
||||
ok = server.graph_compute(input, output);
|
||||
requestq->push(request);
|
||||
responseq->pop(&response);
|
||||
break;
|
||||
}
|
||||
case GET_DEVICE_MEMORY: {
|
||||
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
||||
output.resize(2*sizeof(uint64_t), 0);
|
||||
memcpy(output.data(), &free_mem, sizeof(free_mem));
|
||||
memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
|
||||
response->output.resize(2*sizeof(uint64_t), 0);
|
||||
memcpy(response->output.data(), &free_mem, sizeof(free_mem));
|
||||
memcpy(response->output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
@ -1131,17 +1225,29 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
|||
if (!ok) {
|
||||
break;
|
||||
}
|
||||
uint64_t output_size = output.size();
|
||||
uint64_t output_size = response->output.size();
|
||||
if (!send_data(sockfd, &output_size, sizeof(output_size))) {
|
||||
break;
|
||||
}
|
||||
if (!send_data(sockfd, output.data(), output_size)) {
|
||||
if (!send_data(sockfd, response->output.data(), output_size)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto request = std::make_shared<rpc_request>();
|
||||
request->cmd = FREE_ALL_BUFFERS;
|
||||
requestq->push(request);
|
||||
}
|
||||
|
||||
void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
|
||||
#ifndef _WIN32
|
||||
// prevent SIGPIPE when writing to closed socket
|
||||
signal(SIGPIPE, SIG_IGN);
|
||||
#endif
|
||||
auto requestq = std::make_shared<request_queue>();
|
||||
std::thread backend_thread = std::thread([=] {
|
||||
process_requests(backend, requestq);
|
||||
});
|
||||
|
||||
std::string host;
|
||||
int port;
|
||||
if (!parse_endpoint(endpoint, host, port)) {
|
||||
|
@ -1169,7 +1275,7 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
|
|||
return;
|
||||
}
|
||||
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
||||
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
|
||||
rpc_serve_client(requestq, client_socket->fd, free_mem, total_mem);
|
||||
printf("Client connection closed\n");
|
||||
}
|
||||
#ifdef _WIN32
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue