win32 support

This commit is contained in:
Radoslav Gerganov 2024-05-09 15:10:21 +03:00
parent 3d55181445
commit 7c00fd5184
4 changed files with 137 additions and 58 deletions

View file

@ -498,6 +498,10 @@ endif()
if (LLAMA_RPC) if (LLAMA_RPC)
add_compile_definitions(GGML_USE_RPC) add_compile_definitions(GGML_USE_RPC)
if (WIN32)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ws2_32)
endif()
set(GGML_HEADERS_RPC ggml-rpc.h) set(GGML_HEADERS_RPC ggml-rpc.h)
set(GGML_SOURCES_RPC ggml-rpc.cpp) set(GGML_SOURCES_RPC ggml-rpc.cpp)
endif() endif()

View file

@ -9,14 +9,17 @@
#include "ggml-rpc.h" #include "ggml-rpc.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <sys/types.h> #ifndef _WIN32
#include <sys/socket.h> # include <sys/socket.h>
#include <netinet/in.h> # include <sys/types.h>
#include <netinet/tcp.h> # include <arpa/inet.h>
#include <arpa/inet.h> # include <netinet/in.h>
# include <netinet/tcp.h>
# include <netdb.h>
# include <unistd.h>
#endif
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <unistd.h>
static ggml_backend_t create_backend() { static ggml_backend_t create_backend() {
ggml_backend_t backend = NULL; ggml_backend_t backend = NULL;
@ -52,10 +55,24 @@ static void get_backend_memory(size_t * free_mem, size_t * total_mem) {
#endif #endif
} }
static int create_server_socket(const char * host, int port) { static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
int sockfd = socket(AF_INET, SOCK_STREAM, 0); #ifdef _WIN32
if (sockfd < 0) { if (fd == INVALID_SOCKET) {
return -1; return nullptr;
}
#else
if (fd < 0) {
return nullptr;
}
#endif
return std::make_shared<socket_t>(fd);
}
static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
auto sock = make_socket(sockfd);
if (sock == nullptr) {
return nullptr;
} }
struct sockaddr_in serv_addr; struct sockaddr_in serv_addr;
@ -64,16 +81,24 @@ static int create_server_socket(const char * host, int port) {
serv_addr.sin_port = htons(port); serv_addr.sin_port = htons(port);
if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
return -1; return nullptr;
} }
if (listen(sockfd, 5) < 0) { if (listen(sockfd, 5) < 0) {
return -1; return nullptr;
} }
return sockfd; return sock;
} }
int main(int argc, char * argv[]) int main(int argc, char * argv[])
{ {
#ifdef _WIN32
WSADATA wsaData;
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
if (res != 0) {
fprintf(stderr, "WSAStartup failed: %d\n", res);
return 1;
}
#endif
if (argc < 3) { if (argc < 3) {
fprintf(stderr, "Usage: %s <host> <port>\n", argv[0]); fprintf(stderr, "Usage: %s <host> <port>\n", argv[0]);
return 1; return 1;
@ -88,33 +113,33 @@ int main(int argc, char * argv[])
} }
printf("Starting RPC server on %s:%d\n", host, port); printf("Starting RPC server on %s:%d\n", host, port);
int server_socket = create_server_socket(host, port); auto server_socket = create_server_socket(host, port);
if (server_socket < 0) { if (server_socket == nullptr) {
fprintf(stderr, "Failed to create server socket\n"); fprintf(stderr, "Failed to create server socket\n");
return 1; return 1;
} }
while (true) { while (true) {
struct sockaddr_in cli_addr; auto client_socket_fd = accept(server_socket->fd, NULL, NULL);
socklen_t clilen = sizeof(cli_addr); auto client_socket = make_socket(client_socket_fd);
int client_socket = accept(server_socket, (struct sockaddr *) &cli_addr, &clilen); if (client_socket == nullptr) {
if (client_socket < 0) {
fprintf(stderr, "Failed to accept client connection\n"); fprintf(stderr, "Failed to accept client connection\n");
return 1; return 1;
} }
// set TCP_NODELAY to disable Nagle's algorithm // set TCP_NODELAY to disable Nagle's algorithm
int flag = 1; int flag = 1;
int ret = setsockopt(client_socket, IPPROTO_TCP, TCP_NODELAY, (char *) &flag, sizeof(int)); int ret = setsockopt(client_socket->fd, IPPROTO_TCP, TCP_NODELAY, (char *) &flag, sizeof(int));
if (ret < 0) { if (ret < 0) {
fprintf(stderr, "Failed to set TCP_NODELAY\n"); fprintf(stderr, "Failed to set TCP_NODELAY\n");
close(client_socket);
continue; continue;
} }
size_t free_mem, total_mem; size_t free_mem, total_mem;
get_backend_memory(&free_mem, &total_mem); get_backend_memory(&free_mem, &total_mem);
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem); printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
rpc_serve_client(backend, client_socket, free_mem, total_mem); rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
printf("Client connection closed\n"); printf("Client connection closed\n");
close(client_socket);
} }
#ifdef _WIN32
WSACleanup();
#endif
return 0; return 0;
} }

View file

@ -2,18 +2,19 @@
#include "ggml.h" #include "ggml.h"
#include "ggml-backend-impl.h" #include "ggml-backend-impl.h"
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <sys/socket.h> #ifndef _WIN32
#include <sys/types.h> # include <sys/socket.h>
#include <netinet/in.h> # include <sys/types.h>
#include <netinet/tcp.h> # include <netinet/in.h>
#include <netdb.h> # include <netinet/tcp.h>
# include <netdb.h>
# include <unistd.h>
#endif
#include <string.h> #include <string.h>
#include <unistd.h>
#define UNUSED GGML_UNUSED #define UNUSED GGML_UNUSED
@ -24,15 +25,11 @@
#define GGML_PRINT_DEBUG(...) #define GGML_PRINT_DEBUG(...)
#endif #endif
// RPC data structures #ifdef _WIN32
using ssize_t = __int64;
#endif
struct sockfd { // RPC data structures
int fd;
sockfd(int fd) : fd(fd) {}
~sockfd() {
close(fd);
}
};
static ggml_guid_t ggml_backend_rpc_guid() { static ggml_guid_t ggml_backend_rpc_guid() {
static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03}; static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
@ -40,7 +37,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
} }
struct ggml_backend_rpc_buffer_type_context { struct ggml_backend_rpc_buffer_type_context {
std::shared_ptr<sockfd> sock; std::shared_ptr<socket_t> sock;
std::string name; std::string name;
size_t alignment; size_t alignment;
size_t max_size; size_t max_size;
@ -49,27 +46,47 @@ struct ggml_backend_rpc_buffer_type_context {
struct ggml_backend_rpc_context { struct ggml_backend_rpc_context {
std::string endpoint; std::string endpoint;
std::string name; std::string name;
std::shared_ptr<sockfd> sock; std::shared_ptr<socket_t> sock;
ggml_backend_buffer_type_t buft; ggml_backend_buffer_type_t buft;
}; };
struct ggml_backend_rpc_buffer_context { struct ggml_backend_rpc_buffer_context {
std::shared_ptr<sockfd> sock; std::shared_ptr<socket_t> sock;
std::unordered_map<ggml_backend_buffer_t, void *> base_cache; std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
uint64_t remote_ptr; uint64_t remote_ptr;
std::string name; std::string name;
}; };
// RPC helper functions // RPC helper functions
static std::shared_ptr<sockfd> socket_connect(const char * host, int port) { socket_t::~socket_t() {
struct sockaddr_in addr; #ifdef _WIN32
int sock = socket(AF_INET, SOCK_STREAM, 0); closesocket(this->fd);
if (sock < 0) { #else
close(this->fd);
#endif
}
// Wrap a raw socket descriptor in a shared socket_t.
// Returns nullptr when `fd` is the platform's failure value (INVALID_SOCKET on
// Windows, any negative descriptor elsewhere), so callers can pass the result
// of socket()/accept() straight in and test the returned pointer.
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
#ifdef _WIN32
    const bool is_valid = (fd != INVALID_SOCKET);
#else
    const bool is_valid = (fd >= 0);
#endif
    if (!is_valid) {
        return nullptr;
    }
    return std::make_shared<socket_t>(fd);
}
static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
struct sockaddr_in addr;
auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
auto sock_ptr = make_socket(sockfd);
if (sock_ptr == nullptr) {
return nullptr; return nullptr;
} }
auto sock_ptr = std::make_shared<sockfd>(sock);
// set TCP_NODELAY to disable Nagle's algorithm // set TCP_NODELAY to disable Nagle's algorithm
int flag = 1; int flag = 1;
int ret = setsockopt(sock_ptr->fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); int ret = setsockopt(sock_ptr->fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
@ -83,17 +100,17 @@ static std::shared_ptr<sockfd> socket_connect(const char * host, int port) {
fprintf(stderr, "Cannot resolve host '%s'\n", host); fprintf(stderr, "Cannot resolve host '%s'\n", host);
return nullptr; return nullptr;
} }
bcopy((char *)server->h_addr, (char *)&addr.sin_addr.s_addr, server->h_length); memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
return nullptr; return nullptr;
} }
return sock_ptr; return sock_ptr;
} }
static bool send_data(int sockfd, const void * data, size_t size) { static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
size_t bytes_sent = 0; size_t bytes_sent = 0;
while (bytes_sent < size) { while (bytes_sent < size) {
ssize_t n = send(sockfd, (const uint8_t *)data + bytes_sent, size - bytes_sent, 0); ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
if (n < 0) { if (n < 0) {
return false; return false;
} }
@ -102,10 +119,10 @@ static bool send_data(int sockfd, const void * data, size_t size) {
return true; return true;
} }
static bool recv_data(int sockfd, void * data, size_t size) { static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
size_t bytes_recv = 0; size_t bytes_recv = 0;
while (bytes_recv < size) { while (bytes_recv < size) {
ssize_t n = recv(sockfd, (uint8_t *)data + bytes_recv, size - bytes_recv, 0); ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
if (n <= 0) { if (n <= 0) {
return false; return false;
} }
@ -116,7 +133,7 @@ static bool recv_data(int sockfd, void * data, size_t size) {
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) | // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
static bool send_rpc_cmd(const std::shared_ptr<sockfd> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
uint8_t cmd_byte = cmd; uint8_t cmd_byte = cmd;
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
return false; return false;
@ -348,7 +365,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
return buffer; return buffer;
} }
static size_t get_alignment(const std::shared_ptr<sockfd> & sock) { static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
// input serialization format: | 0 bytes | // input serialization format: | 0 bytes |
std::vector<uint8_t> input; std::vector<uint8_t> input;
std::vector<uint8_t> output; std::vector<uint8_t> output;
@ -366,7 +383,7 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_
return buft_ctx->alignment; return buft_ctx->alignment;
} }
static size_t get_max_size(const std::shared_ptr<sockfd> & sock) { static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
// input serialization format: | 0 bytes | // input serialization format: | 0 bytes |
std::vector<uint8_t> input; std::vector<uint8_t> input;
std::vector<uint8_t> output; std::vector<uint8_t> output;
@ -522,6 +539,15 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
if (instances.find(endpoint) != instances.end()) { if (instances.find(endpoint) != instances.end()) {
return instances[endpoint]; return instances[endpoint];
} }
#ifdef _WIN32
{
WSADATA wsaData;
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
if (res != 0) {
return nullptr;
}
}
#endif
GGML_PRINT_DEBUG("Connecting to %s\n", endpoint.c_str()); GGML_PRINT_DEBUG("Connecting to %s\n", endpoint.c_str());
// split the endpoint into host and port // split the endpoint into host and port
size_t pos = endpoint.find(":"); size_t pos = endpoint.find(":");
@ -565,7 +591,7 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid()); return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
} }
static void get_device_memory(const std::shared_ptr<sockfd> & sock, size_t * free, size_t * total) { static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
// input serialization format: | 0 bytes | // input serialization format: | 0 bytes |
std::vector<uint8_t> input; std::vector<uint8_t> input;
std::vector<uint8_t> output; std::vector<uint8_t> output;
@ -779,7 +805,7 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
ggml_free(ctx); ggml_free(ctx);
} }
void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem) { void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
while (true) { while (true) {
uint8_t cmd; uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) { if (!recv_data(sockfd, &cmd, 1)) {

View file

@ -3,11 +3,35 @@
#include "ggml.h" #include "ggml.h"
#include "ggml-backend.h" #include "ggml-backend.h"
#include <string> #include <string>
#include <memory>
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# include <winsock2.h>
#endif
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
// cross-platform socket fd
#ifdef _WIN32
typedef SOCKET sockfd_t;
#else
typedef int sockfd_t;
#endif

// cross-platform socket
// Owning wrapper around a socket descriptor. The destructor is defined out of
// line and closes `fd` (closesocket() on Windows, close() elsewhere).
// Because the descriptor is closed on destruction, the type is non-copyable:
// a copy would close the same descriptor twice. Share it via
// std::shared_ptr<socket_t> instead.
struct socket_t {
    sockfd_t fd;
    // explicit: a raw descriptor must never silently turn into an owning object
    explicit socket_t(sockfd_t fd) : fd(fd) {}
    ~socket_t();
    socket_t(const socket_t &) = delete;
    socket_t & operator=(const socket_t &) = delete;
};
// ggml_tensor is serialized into rpc_tensor // ggml_tensor is serialized into rpc_tensor
struct rpc_tensor { struct rpc_tensor {
uint64_t id; uint64_t id;
@ -50,7 +74,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total); GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total);
GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem); GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem);
#ifdef __cplusplus #ifdef __cplusplus
} }