Merge 005cf2e662
into 8854044561
This commit is contained in:
commit
777eb3bb0d
1 changed files with 323 additions and 41 deletions
360
ggml-rpc.cpp
360
ggml-rpc.cpp
|
@ -5,8 +5,11 @@
|
||||||
#include <cinttypes>
|
#include <cinttypes>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <queue>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
#include <thread>
|
||||||
|
#include <condition_variable>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
|
@ -17,6 +20,7 @@
|
||||||
# include <windows.h>
|
# include <windows.h>
|
||||||
# include <winsock2.h>
|
# include <winsock2.h>
|
||||||
#else
|
#else
|
||||||
|
# include <signal.h>
|
||||||
# include <arpa/inet.h>
|
# include <arpa/inet.h>
|
||||||
# include <sys/socket.h>
|
# include <sys/socket.h>
|
||||||
# include <sys/types.h>
|
# include <sys/types.h>
|
||||||
|
@ -82,7 +86,8 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
|
||||||
|
|
||||||
// RPC commands
|
// RPC commands
|
||||||
enum rpc_cmd {
|
enum rpc_cmd {
|
||||||
ALLOC_BUFFER = 0,
|
HELLO = 0,
|
||||||
|
ALLOC_BUFFER,
|
||||||
GET_ALIGNMENT,
|
GET_ALIGNMENT,
|
||||||
GET_MAX_SIZE,
|
GET_MAX_SIZE,
|
||||||
BUFFER_GET_BASE,
|
BUFFER_GET_BASE,
|
||||||
|
@ -91,8 +96,15 @@ enum rpc_cmd {
|
||||||
SET_TENSOR,
|
SET_TENSOR,
|
||||||
GET_TENSOR,
|
GET_TENSOR,
|
||||||
COPY_TENSOR,
|
COPY_TENSOR,
|
||||||
|
REMOTE_COPY_TENSOR,
|
||||||
GRAPH_COMPUTE,
|
GRAPH_COMPUTE,
|
||||||
GET_DEVICE_MEMORY,
|
GET_DEVICE_MEMORY,
|
||||||
|
FREE_ALL_BUFFERS,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum rpc_actor {
|
||||||
|
CLIENT = 0,
|
||||||
|
SERVER,
|
||||||
};
|
};
|
||||||
|
|
||||||
// RPC data structures
|
// RPC data structures
|
||||||
|
@ -115,6 +127,7 @@ struct ggml_backend_rpc_context {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_backend_rpc_buffer_context {
|
struct ggml_backend_rpc_buffer_context {
|
||||||
|
std::string endpoint;
|
||||||
std::shared_ptr<socket_t> sock;
|
std::shared_ptr<socket_t> sock;
|
||||||
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
|
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
|
||||||
uint64_t remote_ptr;
|
uint64_t remote_ptr;
|
||||||
|
@ -205,7 +218,7 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
|
||||||
if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
|
if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (listen(sockfd, 1) < 0) {
|
if (listen(sockfd, 2) < 0) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
return sock;
|
return sock;
|
||||||
|
@ -276,6 +289,15 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
|
||||||
|
|
||||||
// RPC client-side implementation
|
// RPC client-side implementation
|
||||||
|
|
||||||
|
static void send_hello(std::shared_ptr<socket_t> sock, rpc_actor actor) {
|
||||||
|
// input serialization format: | actor (1 byte) |
|
||||||
|
std::vector<uint8_t> input(1, actor);
|
||||||
|
std::vector<uint8_t> output;
|
||||||
|
bool status = send_rpc_cmd(sock, HELLO, input, output);
|
||||||
|
GGML_ASSERT(status);
|
||||||
|
GGML_ASSERT(output.empty());
|
||||||
|
}
|
||||||
|
|
||||||
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
||||||
static std::mutex mutex;
|
static std::mutex mutex;
|
||||||
std::lock_guard<std::mutex> lock(mutex);
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
|
@ -309,6 +331,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
||||||
if (sock == nullptr) {
|
if (sock == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
send_hello(sock, CLIENT);
|
||||||
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
||||||
sockets[endpoint] = sock;
|
sockets[endpoint] = sock;
|
||||||
return sock;
|
return sock;
|
||||||
|
@ -422,6 +445,29 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
|
||||||
memcpy(data, output.data(), size);
|
memcpy(data, output.data(), size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool remote_copy_tensor(const ggml_tensor * src, ggml_tensor * dst) {
|
||||||
|
ggml_backend_buffer_t src_buffer = src->buffer;
|
||||||
|
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
|
||||||
|
ggml_backend_buffer_t dst_buffer = dst->buffer;
|
||||||
|
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
|
||||||
|
// input serialization format: | rpc_tensor src | rpc_tensor dst | dst_endpoint_size (4 bytes) | dst_endpoint (dst_endpoint_size bytes) |
|
||||||
|
int input_size = 2*sizeof(rpc_tensor) + sizeof(uint32_t) + dst_ctx->endpoint.size();
|
||||||
|
std::vector<uint8_t> input(input_size, 0);
|
||||||
|
rpc_tensor rpc_src = serialize_tensor(src);
|
||||||
|
rpc_tensor rpc_dst = serialize_tensor(dst);
|
||||||
|
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
|
||||||
|
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
|
||||||
|
uint32_t dst_endpoint_size = dst_ctx->endpoint.size();
|
||||||
|
memcpy(input.data() + 2*sizeof(rpc_tensor), &dst_endpoint_size, sizeof(dst_endpoint_size));
|
||||||
|
memcpy(input.data() + 2*sizeof(rpc_tensor) + sizeof(dst_endpoint_size), dst_ctx->endpoint.c_str(), dst_endpoint_size);
|
||||||
|
std::vector<uint8_t> output;
|
||||||
|
bool status = send_rpc_cmd(src_ctx->sock, REMOTE_COPY_TENSOR, input, output);
|
||||||
|
GGML_ASSERT(status);
|
||||||
|
// output serialization format: | result (1 byte) |
|
||||||
|
GGML_ASSERT(output.size() == 1);
|
||||||
|
return output[0];
|
||||||
|
}
|
||||||
|
|
||||||
GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
|
||||||
// check if src and dst are on the same server
|
// check if src and dst are on the same server
|
||||||
ggml_backend_buffer_t src_buffer = src->buffer;
|
ggml_backend_buffer_t src_buffer = src->buffer;
|
||||||
|
@ -429,7 +475,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
|
||||||
ggml_backend_buffer_t dst_buffer = dst->buffer;
|
ggml_backend_buffer_t dst_buffer = dst->buffer;
|
||||||
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
|
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
|
||||||
if (src_ctx->sock != dst_ctx->sock) {
|
if (src_ctx->sock != dst_ctx->sock) {
|
||||||
return false;
|
return remote_copy_tensor(src, dst);
|
||||||
}
|
}
|
||||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||||
// input serialization format: | rpc_tensor src | rpc_tensor dst |
|
// input serialization format: | rpc_tensor src | rpc_tensor dst |
|
||||||
|
@ -495,7 +541,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
||||||
if (remote_ptr != 0) {
|
if (remote_ptr != 0) {
|
||||||
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
||||||
ggml_backend_rpc_buffer_interface,
|
ggml_backend_rpc_buffer_interface,
|
||||||
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
|
new ggml_backend_rpc_buffer_context{buft_ctx->endpoint, sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
|
||||||
remote_size);
|
remote_size);
|
||||||
return buffer;
|
return buffer;
|
||||||
} else {
|
} else {
|
||||||
|
@ -739,6 +785,48 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint
|
||||||
|
|
||||||
// RPC server-side implementation
|
// RPC server-side implementation
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class message_queue {
|
||||||
|
std::queue<T> queue;
|
||||||
|
std::mutex mutex;
|
||||||
|
std::condition_variable cvar;
|
||||||
|
|
||||||
|
public:
|
||||||
|
message_queue() {}
|
||||||
|
|
||||||
|
void push(const T &value) {
|
||||||
|
std::unique_lock<std::mutex> lock(mutex);
|
||||||
|
queue.push(value);
|
||||||
|
lock.unlock();
|
||||||
|
cvar.notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
void pop(T* out) {
|
||||||
|
std::unique_lock<std::mutex> lock(mutex);
|
||||||
|
cvar.wait(lock, [this] { return queue.size() > 0; });
|
||||||
|
*out = queue.front();
|
||||||
|
queue.pop();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct rpc_response {
|
||||||
|
std::vector<uint8_t> output;
|
||||||
|
bool status;
|
||||||
|
};
|
||||||
|
|
||||||
|
using rpc_response_ptr = std::shared_ptr<rpc_response>;
|
||||||
|
using response_queue = message_queue<rpc_response_ptr>;
|
||||||
|
using response_queue_ptr = std::shared_ptr<response_queue>;
|
||||||
|
|
||||||
|
struct rpc_request {
|
||||||
|
rpc_cmd cmd;
|
||||||
|
std::vector<uint8_t> input;
|
||||||
|
response_queue_ptr response_queue;
|
||||||
|
};
|
||||||
|
using rpc_request_ptr = std::shared_ptr<rpc_request>;
|
||||||
|
using request_queue = message_queue<rpc_request_ptr>;
|
||||||
|
using request_queue_ptr = std::shared_ptr<request_queue>;
|
||||||
|
|
||||||
class rpc_server {
|
class rpc_server {
|
||||||
public:
|
public:
|
||||||
rpc_server(ggml_backend_t backend) : backend(backend) {}
|
rpc_server(ggml_backend_t backend) : backend(backend) {}
|
||||||
|
@ -751,10 +839,13 @@ public:
|
||||||
bool free_buffer(const std::vector<uint8_t> & input);
|
bool free_buffer(const std::vector<uint8_t> & input);
|
||||||
bool buffer_clear(const std::vector<uint8_t> & input);
|
bool buffer_clear(const std::vector<uint8_t> & input);
|
||||||
bool set_tensor(const std::vector<uint8_t> & input);
|
bool set_tensor(const std::vector<uint8_t> & input);
|
||||||
|
void remote_set_tensor(std::shared_ptr<socket_t> sock, const rpc_tensor * rpc_src, const rpc_tensor * rpc_dst);
|
||||||
bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
||||||
bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
||||||
|
bool remote_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
||||||
bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
|
||||||
|
|
||||||
|
void free_all_buffers();
|
||||||
private:
|
private:
|
||||||
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
|
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
|
||||||
ggml_tensor * create_node(uint64_t id,
|
ggml_tensor * create_node(uint64_t id,
|
||||||
|
@ -980,6 +1071,65 @@ bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uin
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void rpc_server::remote_set_tensor(std::shared_ptr<socket_t> sock, const rpc_tensor * rpc_src, const rpc_tensor * rpc_dst) {
|
||||||
|
struct ggml_init_params params {
|
||||||
|
/*.mem_size =*/ ggml_tensor_overhead(),
|
||||||
|
/*.mem_buffer =*/ NULL,
|
||||||
|
/*.no_alloc =*/ true,
|
||||||
|
};
|
||||||
|
struct ggml_context * ctx = ggml_init(params);
|
||||||
|
ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
|
||||||
|
size_t src_size = ggml_nbytes(src);
|
||||||
|
|
||||||
|
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
|
||||||
|
size_t offset = 0;
|
||||||
|
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + src_size;
|
||||||
|
std::vector<uint8_t> input(input_size, 0);
|
||||||
|
memcpy(input.data(), rpc_dst, sizeof(rpc_tensor));
|
||||||
|
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
||||||
|
ggml_backend_tensor_get(src, input.data() + sizeof(rpc_tensor) + sizeof(offset), offset, src_size);
|
||||||
|
std::vector<uint8_t> output;
|
||||||
|
bool status = send_rpc_cmd(sock, SET_TENSOR, input, output);
|
||||||
|
GGML_ASSERT(status);
|
||||||
|
ggml_free(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool rpc_server::remote_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
|
||||||
|
// serialization format: | rpc_tensor src | rpc_tensor dst | dst_endpoint_size (4 bytes) | dst_endpoint (dst_endpoint_size bytes) |
|
||||||
|
if (input.size() < 2*sizeof(rpc_tensor) + sizeof(uint32_t)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
|
||||||
|
const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_tensor));
|
||||||
|
uint32_t dst_endpoint_size;
|
||||||
|
memcpy(&dst_endpoint_size, input.data() + 2*sizeof(rpc_tensor), sizeof(dst_endpoint_size));
|
||||||
|
if (input.size() != 2*sizeof(rpc_tensor) + sizeof(uint32_t) + dst_endpoint_size) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// output serialization format: | result (1 byte) |
|
||||||
|
output.resize(1, 0);
|
||||||
|
|
||||||
|
const char * dst_endpoint_ptr = (const char *)(input.data() + 2*sizeof(rpc_tensor) + sizeof(uint32_t));
|
||||||
|
std::string dst_endpoint(dst_endpoint_ptr, dst_endpoint_size);
|
||||||
|
|
||||||
|
std::string host;
|
||||||
|
int port;
|
||||||
|
if (!parse_endpoint(dst_endpoint, host, port)) {
|
||||||
|
output[0] = false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
auto sock = socket_connect(host.c_str(), port);
|
||||||
|
if (sock == nullptr) {
|
||||||
|
output[0] = false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
send_hello(sock, SERVER);
|
||||||
|
remote_set_tensor(sock, rpc_src, rpc_dst);
|
||||||
|
output.resize(1, 0);
|
||||||
|
output[0] = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
ggml_tensor * rpc_server::create_node(uint64_t id,
|
ggml_tensor * rpc_server::create_node(uint64_t id,
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
|
const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
|
||||||
|
@ -1051,97 +1201,211 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
rpc_server::~rpc_server() {
|
void rpc_server::free_all_buffers() {
|
||||||
for (auto buffer : buffers) {
|
for (auto buffer : buffers) {
|
||||||
ggml_backend_buffer_free(buffer);
|
ggml_backend_buffer_free(buffer);
|
||||||
}
|
}
|
||||||
|
buffers.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
rpc_server::~rpc_server() {
|
||||||
|
free_all_buffers();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void process_requests(ggml_backend_t backend, request_queue_ptr requestq) {
|
||||||
rpc_server server(backend);
|
rpc_server server(backend);
|
||||||
while (true) {
|
while (true) {
|
||||||
uint8_t cmd;
|
rpc_request_ptr request;
|
||||||
if (!recv_data(sockfd, &cmd, 1)) {
|
requestq->pop(&request);
|
||||||
break;
|
rpc_response_ptr response = std::make_shared<rpc_response>();
|
||||||
}
|
|
||||||
std::vector<uint8_t> input;
|
|
||||||
std::vector<uint8_t> output;
|
|
||||||
uint64_t input_size;
|
|
||||||
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
input.resize(input_size);
|
|
||||||
if (!recv_data(sockfd, input.data(), input_size)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
bool ok = true;
|
bool ok = true;
|
||||||
switch (cmd) {
|
switch (request->cmd) {
|
||||||
case ALLOC_BUFFER: {
|
case ALLOC_BUFFER: {
|
||||||
ok = server.alloc_buffer(input, output);
|
ok = server.alloc_buffer(request->input, response->output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case GET_ALIGNMENT: {
|
case GET_ALIGNMENT: {
|
||||||
server.get_alignment(output);
|
server.get_alignment(response->output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case GET_MAX_SIZE: {
|
case GET_MAX_SIZE: {
|
||||||
server.get_max_size(output);
|
server.get_max_size(response->output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case BUFFER_GET_BASE: {
|
case BUFFER_GET_BASE: {
|
||||||
ok = server.buffer_get_base(input, output);
|
ok = server.buffer_get_base(request->input, response->output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case FREE_BUFFER: {
|
case FREE_BUFFER: {
|
||||||
ok = server.free_buffer(input);
|
ok = server.free_buffer(request->input);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case BUFFER_CLEAR: {
|
case BUFFER_CLEAR: {
|
||||||
ok = server.buffer_clear(input);
|
ok = server.buffer_clear(request->input);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case SET_TENSOR: {
|
case SET_TENSOR: {
|
||||||
ok = server.set_tensor(input);
|
ok = server.set_tensor(request->input);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case GET_TENSOR: {
|
case GET_TENSOR: {
|
||||||
ok = server.get_tensor(input, output);
|
ok = server.get_tensor(request->input, response->output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case COPY_TENSOR: {
|
case COPY_TENSOR: {
|
||||||
ok = server.copy_tensor(input, output);
|
ok = server.copy_tensor(request->input, response->output);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case REMOTE_COPY_TENSOR: {
|
||||||
|
ok = server.remote_copy_tensor(request->input, response->output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case GRAPH_COMPUTE: {
|
case GRAPH_COMPUTE: {
|
||||||
ok = server.graph_compute(input, output);
|
ok = server.graph_compute(request->input, response->output);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case GET_DEVICE_MEMORY: {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case FREE_ALL_BUFFERS: {
|
||||||
|
server.free_all_buffers();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
fprintf(stderr, "Unknown command: %d\n", request->cmd);
|
||||||
|
ok = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
response->status = ok;
|
||||||
|
request->response_queue->push(response);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool recv_rpc_cmd(sockfd_t sockfd, rpc_cmd & cmd, std::vector<uint8_t> & input) {
|
||||||
|
uint8_t cmd_u8;
|
||||||
|
if (!recv_data(sockfd, &cmd_u8, 1)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
cmd = (rpc_cmd)cmd_u8;
|
||||||
|
uint64_t input_size;
|
||||||
|
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
input.resize(input_size);
|
||||||
|
if (!recv_data(sockfd, input.data(), input_size)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void rpc_serve_client(request_queue_ptr requestq, std::shared_ptr<socket_t> sock, size_t free_mem, size_t total_mem) {
|
||||||
|
auto responseq = std::make_shared<response_queue>();
|
||||||
|
while (true) {
|
||||||
|
auto request = std::make_shared<rpc_request>();
|
||||||
|
if (!recv_rpc_cmd(sock->fd, request->cmd, request->input)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
request->response_queue = responseq;
|
||||||
|
bool ok = true;
|
||||||
|
auto response = std::make_shared<rpc_response>();
|
||||||
|
switch (request->cmd) {
|
||||||
|
case ALLOC_BUFFER:
|
||||||
|
case GET_ALIGNMENT:
|
||||||
|
case GET_MAX_SIZE:
|
||||||
|
case BUFFER_GET_BASE:
|
||||||
|
case FREE_BUFFER:
|
||||||
|
case BUFFER_CLEAR:
|
||||||
|
case SET_TENSOR:
|
||||||
|
case GET_TENSOR:
|
||||||
|
case COPY_TENSOR:
|
||||||
|
case REMOTE_COPY_TENSOR:
|
||||||
|
case GRAPH_COMPUTE: {
|
||||||
|
requestq->push(request);
|
||||||
|
responseq->pop(&response);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case GET_DEVICE_MEMORY: {
|
case GET_DEVICE_MEMORY: {
|
||||||
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
||||||
output.resize(2*sizeof(uint64_t), 0);
|
response->output.resize(2*sizeof(uint64_t), 0);
|
||||||
memcpy(output.data(), &free_mem, sizeof(free_mem));
|
memcpy(response->output.data(), &free_mem, sizeof(free_mem));
|
||||||
memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
|
memcpy(response->output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default: {
|
default: {
|
||||||
fprintf(stderr, "Unknown command: %d\n", cmd);
|
fprintf(stderr, "Unexpected command: %d\n", request->cmd);
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
uint64_t output_size = output.size();
|
uint64_t output_size = response->output.size();
|
||||||
if (!send_data(sockfd, &output_size, sizeof(output_size))) {
|
if (!send_data(sock->fd, &output_size, sizeof(output_size))) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (!send_data(sockfd, output.data(), output_size)) {
|
if (!send_data(sock->fd, response->output.data(), output_size)) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
auto request = std::make_shared<rpc_request>();
|
||||||
|
request->cmd = FREE_ALL_BUFFERS;
|
||||||
|
requestq->push(request);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void rpc_serve_server(request_queue_ptr requestq, std::shared_ptr<socket_t> sock) {
|
||||||
|
auto responseq = std::make_shared<response_queue>();
|
||||||
|
auto request = std::make_shared<rpc_request>();
|
||||||
|
if (!recv_rpc_cmd(sock->fd, request->cmd, request->input)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (request->cmd != SET_TENSOR) {
|
||||||
|
fprintf(stderr, "Unexpected command: %d\n", request->cmd);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
request->response_queue = responseq;
|
||||||
|
auto response = std::make_shared<rpc_response>();
|
||||||
|
requestq->push(request);
|
||||||
|
responseq->pop(&response);
|
||||||
|
uint64_t output_size = response->output.size();
|
||||||
|
if (!send_data(sock->fd, &output_size, sizeof(output_size))) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
send_data(sock->fd, response->output.data(), output_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool recv_hello(std::shared_ptr<socket_t> sock, rpc_actor & actor) {
|
||||||
|
rpc_cmd cmd;
|
||||||
|
std::vector<uint8_t> input;
|
||||||
|
if (!recv_rpc_cmd(sock->fd, cmd, input)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (cmd != HELLO || input.size() != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (input[0] != CLIENT && input[0] != SERVER) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
actor = (rpc_actor)input[0];
|
||||||
|
uint64_t output_size = 0;
|
||||||
|
if (!send_data(sock->fd, &output_size, sizeof(output_size))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::mutex client_mutex;
|
||||||
|
static std::mutex server_mutex;
|
||||||
|
|
||||||
void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
|
void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
|
||||||
|
#ifndef _WIN32
|
||||||
|
// prevent SIGPIPE when writing to closed socket
|
||||||
|
signal(SIGPIPE, SIG_IGN);
|
||||||
|
#endif
|
||||||
|
auto requestq = std::make_shared<request_queue>();
|
||||||
|
std::thread backend_thread = std::thread([=] {
|
||||||
|
process_requests(backend, requestq);
|
||||||
|
});
|
||||||
|
|
||||||
std::string host;
|
std::string host;
|
||||||
int port;
|
int port;
|
||||||
if (!parse_endpoint(endpoint, host, port)) {
|
if (!parse_endpoint(endpoint, host, port)) {
|
||||||
|
@ -1168,9 +1432,27 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
|
||||||
fprintf(stderr, "Failed to accept client connection\n");
|
fprintf(stderr, "Failed to accept client connection\n");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
rpc_actor actor;
|
||||||
|
if (!recv_hello(client_socket, actor)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (actor == CLIENT) {
|
||||||
|
std::thread client_thread = std::thread([=] {
|
||||||
|
std::lock_guard<std::mutex> lock(client_mutex);
|
||||||
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
||||||
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
|
rpc_serve_client(requestq, client_socket, free_mem, total_mem);
|
||||||
printf("Client connection closed\n");
|
printf("Client connection closed\n");
|
||||||
|
});
|
||||||
|
client_thread.detach();
|
||||||
|
} else {
|
||||||
|
std::thread server_thread = std::thread([=] {
|
||||||
|
std::lock_guard<std::mutex> lock(server_mutex);
|
||||||
|
printf("Accepted connection from another server\n");
|
||||||
|
rpc_serve_server(requestq, client_socket);
|
||||||
|
printf("Server connection closed\n");
|
||||||
|
});
|
||||||
|
server_thread.detach();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
WSACleanup();
|
WSACleanup();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue