rpc : copy tensors across servers

Add new cmd REMOTE_COPY_TENSOR for copying a tensor from one server to
another.
This commit is contained in:
Radoslav Gerganov 2024-06-18 16:28:46 +03:00
parent d47e1371b0
commit 005cf2e662

View file

@ -86,7 +86,8 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
// RPC commands
enum rpc_cmd {
ALLOC_BUFFER = 0,
HELLO = 0,
ALLOC_BUFFER,
GET_ALIGNMENT,
GET_MAX_SIZE,
BUFFER_GET_BASE,
@ -95,11 +96,17 @@ enum rpc_cmd {
SET_TENSOR,
GET_TENSOR,
COPY_TENSOR,
REMOTE_COPY_TENSOR,
GRAPH_COMPUTE,
GET_DEVICE_MEMORY,
FREE_ALL_BUFFERS,
};
enum rpc_actor {
CLIENT = 0,
SERVER,
};
// RPC data structures
static ggml_guid_t ggml_backend_rpc_guid() {
@ -120,6 +127,7 @@ struct ggml_backend_rpc_context {
};
struct ggml_backend_rpc_buffer_context {
std::string endpoint;
std::shared_ptr<socket_t> sock;
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
uint64_t remote_ptr;
@ -210,7 +218,7 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
return nullptr;
}
if (listen(sockfd, 1) < 0) {
if (listen(sockfd, 2) < 0) {
return nullptr;
}
return sock;
@ -281,6 +289,15 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
// RPC client-side implementation
static void send_hello(std::shared_ptr<socket_t> sock, rpc_actor actor) {
// input serialization format: | actor (1 byte) |
std::vector<uint8_t> input(1, actor);
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, HELLO, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.empty());
}
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
@ -314,6 +331,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
if (sock == nullptr) {
return nullptr;
}
send_hello(sock, CLIENT);
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
sockets[endpoint] = sock;
return sock;
@ -427,6 +445,29 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
memcpy(data, output.data(), size);
}
static bool remote_copy_tensor(const ggml_tensor * src, ggml_tensor * dst) {
ggml_backend_buffer_t src_buffer = src->buffer;
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
ggml_backend_buffer_t dst_buffer = dst->buffer;
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
// input serialization format: | rpc_tensor src | rpc_tensor dst | dst_endpoint_size (4 bytes) | dst_endpoint (dst_endpoint_size bytes) |
int input_size = 2*sizeof(rpc_tensor) + sizeof(uint32_t) + dst_ctx->endpoint.size();
std::vector<uint8_t> input(input_size, 0);
rpc_tensor rpc_src = serialize_tensor(src);
rpc_tensor rpc_dst = serialize_tensor(dst);
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
uint32_t dst_endpoint_size = dst_ctx->endpoint.size();
memcpy(input.data() + 2*sizeof(rpc_tensor), &dst_endpoint_size, sizeof(dst_endpoint_size));
memcpy(input.data() + 2*sizeof(rpc_tensor) + sizeof(dst_endpoint_size), dst_ctx->endpoint.c_str(), dst_endpoint_size);
std::vector<uint8_t> output;
bool status = send_rpc_cmd(src_ctx->sock, REMOTE_COPY_TENSOR, input, output);
GGML_ASSERT(status);
// output serialization format: | result (1 byte) |
GGML_ASSERT(output.size() == 1);
return output[0];
}
GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
// check if src and dst are on the same server
ggml_backend_buffer_t src_buffer = src->buffer;
@ -434,7 +475,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
ggml_backend_buffer_t dst_buffer = dst->buffer;
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
if (src_ctx->sock != dst_ctx->sock) {
return false;
return remote_copy_tensor(src, dst);
}
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
// input serialization format: | rpc_tensor src | rpc_tensor dst |
@ -500,7 +541,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
if (remote_ptr != 0) {
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface,
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
new ggml_backend_rpc_buffer_context{buft_ctx->endpoint, sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
remote_size);
return buffer;
} else {
@ -798,8 +839,10 @@ public:
bool free_buffer(const std::vector<uint8_t> & input);
bool buffer_clear(const std::vector<uint8_t> & input);
bool set_tensor(const std::vector<uint8_t> & input);
void remote_set_tensor(std::shared_ptr<socket_t> sock, const rpc_tensor * rpc_src, const rpc_tensor * rpc_dst);
bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
bool remote_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
void free_all_buffers();
@ -1028,6 +1071,65 @@ bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uin
return true;
}
void rpc_server::remote_set_tensor(std::shared_ptr<socket_t> sock, const rpc_tensor * rpc_src, const rpc_tensor * rpc_dst) {
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
struct ggml_context * ctx = ggml_init(params);
ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
size_t src_size = ggml_nbytes(src);
// input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
size_t offset = 0;
size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + src_size;
std::vector<uint8_t> input(input_size, 0);
memcpy(input.data(), rpc_dst, sizeof(rpc_tensor));
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
ggml_backend_tensor_get(src, input.data() + sizeof(rpc_tensor) + sizeof(offset), offset, src_size);
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, SET_TENSOR, input, output);
GGML_ASSERT(status);
ggml_free(ctx);
}
bool rpc_server::remote_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
// serialization format: | rpc_tensor src | rpc_tensor dst | dst_endpoint_size (4 bytes) | dst_endpoint (dst_endpoint_size bytes) |
if (input.size() < 2*sizeof(rpc_tensor) + sizeof(uint32_t)) {
return false;
}
const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_tensor));
uint32_t dst_endpoint_size;
memcpy(&dst_endpoint_size, input.data() + 2*sizeof(rpc_tensor), sizeof(dst_endpoint_size));
if (input.size() != 2*sizeof(rpc_tensor) + sizeof(uint32_t) + dst_endpoint_size) {
return false;
}
// output serialization format: | result (1 byte) |
output.resize(1, 0);
const char * dst_endpoint_ptr = (const char *)(input.data() + 2*sizeof(rpc_tensor) + sizeof(uint32_t));
std::string dst_endpoint(dst_endpoint_ptr, dst_endpoint_size);
std::string host;
int port;
if (!parse_endpoint(dst_endpoint, host, port)) {
output[0] = false;
return true;
}
auto sock = socket_connect(host.c_str(), port);
if (sock == nullptr) {
output[0] = false;
return true;
}
send_hello(sock, SERVER);
remote_set_tensor(sock, rpc_src, rpc_dst);
output.resize(1, 0);
output[0] = true;
return true;
}
ggml_tensor * rpc_server::create_node(uint64_t id,
struct ggml_context * ctx,
const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
@ -1154,6 +1256,10 @@ static void process_requests(ggml_backend_t backend, request_queue_ptr requestq)
ok = server.copy_tensor(request->input, response->output);
break;
}
case REMOTE_COPY_TENSOR: {
ok = server.remote_copy_tensor(request->input, response->output);
break;
}
case GRAPH_COMPUTE: {
ok = server.graph_compute(request->input, response->output);
break;
@ -1175,27 +1281,34 @@ static void process_requests(ggml_backend_t backend, request_queue_ptr requestq)
}
}
static void rpc_serve_client(request_queue_ptr requestq, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
static bool recv_rpc_cmd(sockfd_t sockfd, rpc_cmd & cmd, std::vector<uint8_t> & input) {
uint8_t cmd_u8;
if (!recv_data(sockfd, &cmd_u8, 1)) {
return false;
}
cmd = (rpc_cmd)cmd_u8;
uint64_t input_size;
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
return false;
}
input.resize(input_size);
if (!recv_data(sockfd, input.data(), input_size)) {
return false;
}
return true;
}
static void rpc_serve_client(request_queue_ptr requestq, std::shared_ptr<socket_t> sock, size_t free_mem, size_t total_mem) {
auto responseq = std::make_shared<response_queue>();
while (true) {
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
break;
}
auto request = std::make_shared<rpc_request>();
request->cmd = (rpc_cmd)cmd;
if (!recv_rpc_cmd(sock->fd, request->cmd, request->input)) {
break;
}
request->response_queue = responseq;
uint64_t input_size;
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
break;
}
request->input.resize(input_size);
if (!recv_data(sockfd, request->input.data(), input_size)) {
break;
}
bool ok = true;
auto response = std::make_shared<rpc_response>();
switch (cmd) {
switch (request->cmd) {
case ALLOC_BUFFER:
case GET_ALIGNMENT:
case GET_MAX_SIZE:
@ -1205,6 +1318,7 @@ static void rpc_serve_client(request_queue_ptr requestq, sockfd_t sockfd, size_t
case SET_TENSOR:
case GET_TENSOR:
case COPY_TENSOR:
case REMOTE_COPY_TENSOR:
case GRAPH_COMPUTE: {
requestq->push(request);
responseq->pop(&response);
@ -1218,7 +1332,7 @@ static void rpc_serve_client(request_queue_ptr requestq, sockfd_t sockfd, size_t
break;
}
default: {
fprintf(stderr, "Unknown command: %d\n", cmd);
fprintf(stderr, "Unexpected command: %d\n", request->cmd);
ok = false;
}
}
@ -1226,10 +1340,10 @@ static void rpc_serve_client(request_queue_ptr requestq, sockfd_t sockfd, size_t
break;
}
uint64_t output_size = response->output.size();
if (!send_data(sockfd, &output_size, sizeof(output_size))) {
if (!send_data(sock->fd, &output_size, sizeof(output_size))) {
break;
}
if (!send_data(sockfd, response->output.data(), output_size)) {
if (!send_data(sock->fd, response->output.data(), output_size)) {
break;
}
}
@ -1238,6 +1352,50 @@ static void rpc_serve_client(request_queue_ptr requestq, sockfd_t sockfd, size_t
requestq->push(request);
}
static void rpc_serve_server(request_queue_ptr requestq, std::shared_ptr<socket_t> sock) {
auto responseq = std::make_shared<response_queue>();
auto request = std::make_shared<rpc_request>();
if (!recv_rpc_cmd(sock->fd, request->cmd, request->input)) {
return;
}
if (request->cmd != SET_TENSOR) {
fprintf(stderr, "Unexpected command: %d\n", request->cmd);
return;
}
request->response_queue = responseq;
auto response = std::make_shared<rpc_response>();
requestq->push(request);
responseq->pop(&response);
uint64_t output_size = response->output.size();
if (!send_data(sock->fd, &output_size, sizeof(output_size))) {
return;
}
send_data(sock->fd, response->output.data(), output_size);
}
static bool recv_hello(std::shared_ptr<socket_t> sock, rpc_actor & actor) {
rpc_cmd cmd;
std::vector<uint8_t> input;
if (!recv_rpc_cmd(sock->fd, cmd, input)) {
return false;
}
if (cmd != HELLO || input.size() != 1) {
return false;
}
if (input[0] != CLIENT && input[0] != SERVER) {
return false;
}
actor = (rpc_actor)input[0];
uint64_t output_size = 0;
if (!send_data(sock->fd, &output_size, sizeof(output_size))) {
return false;
}
return true;
}
static std::mutex client_mutex;
static std::mutex server_mutex;
void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
#ifndef _WIN32
// prevent SIGPIPE when writing to closed socket
@ -1274,9 +1432,27 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
fprintf(stderr, "Failed to accept client connection\n");
return;
}
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
rpc_serve_client(requestq, client_socket->fd, free_mem, total_mem);
printf("Client connection closed\n");
rpc_actor actor;
if (!recv_hello(client_socket, actor)) {
continue;
}
if (actor == CLIENT) {
std::thread client_thread = std::thread([=] {
std::lock_guard<std::mutex> lock(client_mutex);
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
rpc_serve_client(requestq, client_socket, free_mem, total_mem);
printf("Client connection closed\n");
});
client_thread.detach();
} else {
std::thread server_thread = std::thread([=] {
std::lock_guard<std::mutex> lock(server_mutex);
printf("Accepted connection from another server\n");
rpc_serve_server(requestq, client_socket);
printf("Server connection closed\n");
});
server_thread.detach();
}
}
#ifdef _WIN32
WSACleanup();