diff --git a/CMakeLists.txt b/CMakeLists.txt index d61fa4520..feb6f39d0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -498,6 +498,10 @@ endif() if (LLAMA_RPC) add_compile_definitions(GGML_USE_RPC) + if (WIN32) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ws2_32) + endif() + set(GGML_HEADERS_RPC ggml-rpc.h) set(GGML_SOURCES_RPC ggml-rpc.cpp) endif() diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp index f809aec33..b81c4184b 100644 --- a/examples/rpc/rpc-server.cpp +++ b/examples/rpc/rpc-server.cpp @@ -9,14 +9,17 @@ #include "ggml-rpc.h" #include #include -#include -#include -#include -#include -#include +#ifndef _WIN32 +# include +# include +# include +# include +# include +# include +# include +#endif #include #include -#include static ggml_backend_t create_backend() { ggml_backend_t backend = NULL; @@ -52,10 +55,24 @@ static void get_backend_memory(size_t * free_mem, size_t * total_mem) { #endif } -static int create_server_socket(const char * host, int port) { - int sockfd = socket(AF_INET, SOCK_STREAM, 0); - if (sockfd < 0) { - return -1; +static std::shared_ptr make_socket(sockfd_t fd) { +#ifdef _WIN32 + if (fd == INVALID_SOCKET) { + return nullptr; + } +#else + if (fd < 0) { + return nullptr; + } +#endif + return std::make_shared(fd); +} + +static std::shared_ptr create_server_socket(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + auto sock = make_socket(sockfd); + if (sock == nullptr) { + return nullptr; } struct sockaddr_in serv_addr; @@ -64,16 +81,24 @@ static int create_server_socket(const char * host, int port) { serv_addr.sin_port = htons(port); if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { - return -1; + return nullptr; } if (listen(sockfd, 5) < 0) { - return -1; + return nullptr; } - return sockfd; + return sock; } int main(int argc, char * argv[]) { +#ifdef _WIN32 + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + fprintf(stderr, "WSAStartup failed: %d\n", res); + return 1; + } +#endif if (argc < 3) { fprintf(stderr, "Usage: %s \n", argv[0]); return 1; @@ -88,33 +113,33 @@ int main(int argc, char * argv[]) } printf("Starting RPC server on %s:%d\n", host, port); - int server_socket = create_server_socket(host, port); - if (server_socket < 0) { + auto server_socket = create_server_socket(host, port); + if (server_socket == nullptr) { fprintf(stderr, "Failed to create server socket\n"); return 1; } while (true) { - struct sockaddr_in cli_addr; - socklen_t clilen = sizeof(cli_addr); - int client_socket = accept(server_socket, (struct sockaddr *) &cli_addr, &clilen); - if (client_socket < 0) { + auto client_socket_fd = accept(server_socket->fd, NULL, NULL); + auto client_socket = make_socket(client_socket_fd); + if (client_socket == nullptr) { fprintf(stderr, "Failed to accept client connection\n"); return 1; } // set TCP_NODELAY to disable Nagle's algorithm int flag = 1; - int ret = setsockopt(client_socket, IPPROTO_TCP, TCP_NODELAY, (char *) &flag, sizeof(int)); + int ret = setsockopt(client_socket->fd, IPPROTO_TCP, TCP_NODELAY, (char *) &flag, sizeof(int)); if (ret < 0) { fprintf(stderr, "Failed to set TCP_NODELAY\n"); - close(client_socket); continue; } size_t free_mem, total_mem; get_backend_memory(&free_mem, &total_mem); printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); - rpc_serve_client(backend, client_socket, free_mem, total_mem); + rpc_serve_client(backend, client_socket->fd, free_mem, total_mem); printf("Client connection closed\n"); - close(client_socket); } +#ifdef _WIN32 + WSACleanup(); +#endif return 0; } diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp index 1c21c44a0..84768509c 100644 --- a/ggml-rpc.cpp +++ b/ggml-rpc.cpp @@ -2,18 +2,19 @@ #include "ggml.h" #include "ggml-backend-impl.h" -#include #include #include #include #include -#include -#include -#include -#include -#include +#ifndef _WIN32 +# include +# include +# include +# include +# include +# include +#endif #include -#include #define UNUSED GGML_UNUSED @@ -24,15 +25,11 @@ #define GGML_PRINT_DEBUG(...) #endif -// RPC data structures +#ifdef _WIN32 +using ssize_t = __int64; +#endif -struct sockfd { - int fd; - sockfd(int fd) : fd(fd) {} - ~sockfd() { - close(fd); - } -}; +// RPC data structures static ggml_guid_t ggml_backend_rpc_guid() { static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03}; @@ -40,7 +37,7 @@ static ggml_guid_t ggml_backend_rpc_guid() { } struct ggml_backend_rpc_buffer_type_context { - std::shared_ptr sock; + std::shared_ptr sock; std::string name; size_t alignment; size_t max_size; @@ -49,27 +46,47 @@ struct ggml_backend_rpc_buffer_type_context { struct ggml_backend_rpc_context { std::string endpoint; std::string name; - std::shared_ptr sock; + std::shared_ptr sock; ggml_backend_buffer_type_t buft; }; struct ggml_backend_rpc_buffer_context { - std::shared_ptr sock; + std::shared_ptr sock; std::unordered_map base_cache; uint64_t remote_ptr; std::string name; }; - // RPC helper functions -static std::shared_ptr socket_connect(const char * host, int port) { - struct sockaddr_in addr; - int sock = socket(AF_INET, SOCK_STREAM, 0); - if (sock < 0) { +socket_t::~socket_t() { +#ifdef _WIN32 + closesocket(this->fd); +#else + close(this->fd); +#endif +} + +static std::shared_ptr make_socket(sockfd_t fd) { +#ifdef _WIN32 + if (fd == INVALID_SOCKET) { + return nullptr; + } +#else + if (fd < 0) { + return nullptr; + } +#endif + return std::make_shared(fd); +} + +static std::shared_ptr socket_connect(const char * host, int port) { + struct sockaddr_in addr; + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + auto sock_ptr = make_socket(sockfd); + if (sock_ptr == nullptr) { return nullptr; } - auto sock_ptr = std::make_shared(sock); // set TCP_NODELAY to disable Nagle's algorithm int flag = 1; int ret = setsockopt(sock_ptr->fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); @@ -83,17 +100,17 @@ static std::shared_ptr socket_connect(const char * host, int port) { fprintf(stderr, "Cannot resolve host '%s'\n", host); return nullptr; } - bcopy((char *)server->h_addr, (char *)&addr.sin_addr.s_addr, server->h_length); + memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { return nullptr; } return sock_ptr; } -static bool send_data(int sockfd, const void * data, size_t size) { +static bool send_data(sockfd_t sockfd, const void * data, size_t size) { size_t bytes_sent = 0; while (bytes_sent < size) { - ssize_t n = send(sockfd, (const uint8_t *)data + bytes_sent, size - bytes_sent, 0); + ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0); if (n < 0) { return false; } @@ -102,10 +119,10 @@ static bool send_data(int sockfd, const void * data, size_t size) { return true; } -static bool recv_data(int sockfd, void * data, size_t size) { +static bool recv_data(sockfd_t sockfd, void * data, size_t size) { size_t bytes_recv = 0; while (bytes_recv < size) { - ssize_t n = recv(sockfd, (uint8_t *)data + bytes_recv, size - bytes_recv, 0); + ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0); if (n <= 0) { return false; } @@ -116,7 +133,7 @@ static bool recv_data(int sockfd, void * data, size_t size) { // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // RPC response: | response_size (8 bytes) | response_data (response_size bytes) | -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const std::vector & input, std::vector & output) { +static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const std::vector & input, std::vector & output) { uint8_t cmd_byte = cmd; if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { return false; @@ -348,7 +365,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer return buffer; } -static size_t get_alignment(const std::shared_ptr & sock) { +static size_t get_alignment(const std::shared_ptr & sock) { // input serialization format: | 0 bytes | std::vector input; std::vector output; @@ -366,7 +383,7 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_ return buft_ctx->alignment; } -static size_t get_max_size(const std::shared_ptr & sock) { +static size_t get_max_size(const std::shared_ptr & sock) { // input serialization format: | 0 bytes | std::vector input; std::vector output; @@ -522,6 +539,15 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) { if (instances.find(endpoint) != instances.end()) { return instances[endpoint]; } +#ifdef _WIN32 + { + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return nullptr; + } + } +#endif GGML_PRINT_DEBUG("Connecting to %s\n", endpoint.c_str()); // split the endpoint into host and port size_t pos = endpoint.find(":"); @@ -565,7 +591,7 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) { return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid()); } -static void get_device_memory(const std::shared_ptr & sock, size_t * free, size_t * total) { +static void get_device_memory(const std::shared_ptr & sock, size_t * free, size_t * total) { // input serialization format: | 0 bytes | std::vector input; std::vector output; @@ -779,7 +805,7 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector ggml_free(ctx); } -void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem) { +void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) { while (true) { uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { diff --git a/ggml-rpc.h b/ggml-rpc.h index a7dfaef71..8472d19df 100644 --- a/ggml-rpc.h +++ b/ggml-rpc.h @@ -3,11 +3,35 @@ #include "ggml.h" #include "ggml-backend.h" #include +#include +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#endif #ifdef __cplusplus extern "C" { #endif +// cross-platform socket fd +#ifdef _WIN32 +typedef SOCKET sockfd_t; +#else +typedef int sockfd_t; +#endif + +// cross-platform socket +struct socket_t { + sockfd_t fd; + socket_t(sockfd_t fd) : fd(fd) {} + ~socket_t(); +}; + + // ggml_tensor is serialized into rpc_tensor struct rpc_tensor { uint64_t id; @@ -50,7 +74,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total); -GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem); +GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem); #ifdef __cplusplus }