From ef9be327916740d2bc318bdd6516e5a1b8cb9886 Mon Sep 17 00:00:00 2001
From: Radoslav Gerganov
Date: Tue, 7 May 2024 11:02:02 +0300
Subject: [PATCH] wrap sockfd into a struct

---

Review notes (free text between the three-dash line and the diff is not part
of the commit message):

* This copy was restored from a version whose line breaks, indentation and
  every `<...>` span had been stripped. Template arguments were put back from
  context: `std::make_shared<sockfd>(sock)` fixes `std::shared_ptr<sockfd>`,
  and the hunk heading truncated at `std::vector<uint8_t` fixes the vector
  element type. The first four lines of the `-129,7` hunk were swallowed
  together with that heading and are rebuilt from the hunk line counts.
  The arguments of the untouched `base_cache` map are an assumption - TODO
  confirm against ggml-rpc.cpp. Indentation and alignment inside comments
  are reconstructed; use `--ignore-whitespace` if a strict apply fails.
  The author e-mail on the From: line was lost and has not been invented.
* `struct sockfd` declares a destructor that calls close(fd) but leaves the
  copy operations implicit, so copying a `sockfd` object would close the same
  descriptor twice. Suggested follow-up: `explicit sockfd(int fd)` plus
  deleted copy constructor and copy assignment.
* The new type name `sockfd` is also the name of the `int sockfd` parameter
  of send_data() and recv_data(); a distinct type name would avoid the
  overlap.

 ggml-rpc.cpp | 81 ++++++++++++++++++++++++++++------------------------
 1 file changed, 43 insertions(+), 38 deletions(-)

diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp
index 36dd671f3..3ff5700aa 100644
--- a/ggml-rpc.cpp
+++ b/ggml-rpc.cpp
@@ -26,25 +26,33 @@
 
 // RPC data structures
 
+struct sockfd {
+    int fd;
+    sockfd(int fd) : fd(fd) {}
+    ~sockfd() {
+        close(fd);
+    }
+};
+
 static ggml_guid_t ggml_backend_rpc_guid() {
     static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
     return &guid;
 }
 
 struct ggml_backend_rpc_buffer_type_context {
-    int sockfd;
+    std::shared_ptr<sockfd> sock;
     std::string name;
 };
 
 struct ggml_backend_rpc_context {
     std::string endpoint;
     std::string name;
-    int sockfd;
+    std::shared_ptr<sockfd> sock;
     ggml_backend_buffer_type_t buft;
 };
 
 struct ggml_backend_rpc_buffer_context {
-    int sockfd;
+    std::shared_ptr<sockfd> sock;
     std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
     uint64_t remote_ptr;
     std::string name;
@@ -53,33 +61,31 @@ struct ggml_backend_rpc_buffer_context {
 
 // RPC helper functions
 
-static int socket_connect(const char * host, int port) {
+static std::shared_ptr<sockfd> socket_connect(const char * host, int port) {
     struct sockaddr_in addr;
     int sock = socket(AF_INET, SOCK_STREAM, 0);
     if (sock < 0) {
-        return -1;
+        return nullptr;
     }
+    auto sock_ptr = std::make_shared<sockfd>(sock);
     // set TCP_NODELAY to disable Nagle's algorithm
     int flag = 1;
-    int ret = setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
+    int ret = setsockopt(sock_ptr->fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
     if (ret < 0) {
-        close(sock);
-        return -1;
+        return nullptr;
     }
     addr.sin_family = AF_INET;
     addr.sin_port = htons(port);
     struct hostent * server = gethostbyname(host);
     if (server == NULL) {
         fprintf(stderr, "Cannot resolve host '%s'\n", host);
-        close(sock);
-        return -1;
+        return nullptr;
     }
     bcopy((char *)server->h_addr, (char *)&addr.sin_addr.s_addr, server->h_length);
-    if (connect(sock, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
-        close(sock);
-        return -1;
+    if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
+        return nullptr;
     }
-    return sock;
+    return sock_ptr;
 }
 
 static bool send_data(int sockfd, const void * data, size_t size) {
@@ -108,20 +114,20 @@ static bool recv_data(int sockfd, void * data, size_t size) {
 
 // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
 // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
-static bool send_rpc_cmd(int sockfd, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+static bool send_rpc_cmd(const std::shared_ptr<sockfd> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
     uint8_t cmd_byte = cmd;
-    if (!send_data(sockfd, &cmd_byte, sizeof(cmd_byte))) {
+    if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
         return false;
     }
     uint64_t input_size = input.size();
-    if (!send_data(sockfd, &input_size, sizeof(input_size))) {
+    if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
         return false;
     }
-    if (!send_data(sockfd, input.data(), input.size())) {
+    if (!send_data(sock->fd, input.data(), input.size())) {
         return false;
     }
     uint64_t output_size;
-    if (!recv_data(sockfd, &output_size, sizeof(output_size))) {
+    if (!recv_data(sock->fd, &output_size, sizeof(output_size))) {
         return false;
     }
     if (output_size == 0) {
@@ -129,7 +135,7 @@ static bool send_rpc_cmd(int sockfd, enum rpc_cmd cmd, const std::vector<uint8_t
         return true;
     }
     output.resize(output_size);
-    if (!recv_data(sockfd, output.data(), output_size)) {
+    if (!recv_data(sock->fd, output.data(), output_size)) {
         return false;
     }
     return true;
@@ -149,7 +155,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
     uint64_t remote_ptr = ctx->remote_ptr;
     memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sockfd, FREE_BUFFER, input, output);
+    bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.empty());
     delete ctx;
@@ -165,7 +171,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
     uint64_t remote_ptr = ctx->remote_ptr;
     memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sockfd, BUFFER_GET_BASE, input, output);
+    bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == sizeof(uint64_t));
     // output serialization format: | base_ptr (8 bytes) |
@@ -241,7 +247,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sockfd, SET_TENSOR, input, output);
+    bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
     GGML_ASSERT(status);
 }
 
@@ -255,7 +261,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sockfd, GET_TENSOR, input, output);
+    bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == size);
     // output serialization format: | data (size bytes) |
@@ -268,7 +274,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
     ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
     ggml_backend_buffer_t dst_buffer = dst->buffer;
     ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
-    if (src_ctx->sockfd != dst_ctx->sockfd) {
+    if (src_ctx->sock != dst_ctx->sock) {
         return false;
     }
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
@@ -280,7 +286,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
     memcpy(input.data(), &rpc_src, sizeof(rpc_src));
     memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sockfd, COPY_TENSOR, input, output);
+    bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
     GGML_ASSERT(status);
     // output serialization format: | result (1 byte) |
     GGML_ASSERT(output.size() == 1);
@@ -295,7 +301,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
     memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
     memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sockfd, BUFFER_CLEAR, input, output);
+    bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
     GGML_ASSERT(status);
 }
 
@@ -323,7 +329,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
     std::vector<uint8_t> input(input_size, 0);
     memcpy(input.data(), &size, sizeof(size));
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(buft_ctx->sockfd, ALLOC_BUFFER, input, output);
+    bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
     // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -334,7 +340,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
 
     ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
         ggml_backend_rpc_buffer_interface,
-        new ggml_backend_rpc_buffer_context{buft_ctx->sockfd, {}, remote_ptr, "RPC"},
+        new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
         remote_size);
 
     return buffer;
@@ -363,7 +369,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
     }
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    return buft_ctx->sockfd == rpc_ctx->sockfd;
+    return buft_ctx->sock == rpc_ctx->sock;
 }
 
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -386,7 +392,6 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
 GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
-    //close(rpc_ctx->sockfd);
     delete buft_ctx;
     delete rpc_ctx->buft;
     delete rpc_ctx;
@@ -446,7 +451,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
     std::vector<uint8_t> input;
     serialize_graph(cgraph, input);
     std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(rpc_ctx->sockfd, GRAPH_COMPUTE, input, output);
+    bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
     GGML_ASSERT(status);
     GGML_ASSERT(output.size() == 1);
     return (enum ggml_status)output[0];
@@ -496,14 +501,14 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
     size_t pos = endpoint.find(":");
     std::string host = endpoint.substr(0, pos);
     int port = std::stoi(endpoint.substr(pos + 1));
-    int sockfd = socket_connect(host.c_str(), port);
-    if (sockfd < 0) {
+    auto sock = socket_connect(host.c_str(), port);
+    if (sock == nullptr) {
         return nullptr;
     }
 
     ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
-        /* .sockfd = */ sockfd,
-        /* .name = */ "RPC" + std::to_string(sockfd)
+        /* .sock = */ sock,
+        /* .name = */ "RPC" + std::to_string(sock->fd)
     };
 
     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
@@ -513,8 +518,8 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
 
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
         /* .endpoint = */ endpoint,
-        /* .name = */ "RPC" + std::to_string(sockfd),
-        /* .sockfd = */ sockfd,
+        /* .name = */ "RPC" + std::to_string(sock->fd),
+        /* .sock = */ sock,
         /* .buft = */ buft
     };
 