From dfadd1a82c471e67d06cfa0c1ff3056151bad894 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Tue, 30 Apr 2024 15:27:54 +0300 Subject: [PATCH] Address review comments --- ggml-rpc.cpp | 35 +++++++++++++++++++++++------------ llama.cpp | 2 +- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp index ec7c35dbd..36dd671f3 100644 --- a/ggml-rpc.cpp +++ b/ggml-rpc.cpp @@ -27,7 +27,7 @@ // RPC data structures static ggml_guid_t ggml_backend_rpc_guid() { - static ggml_guid guid = { 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff}; + static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03}; return &guid; } @@ -45,6 +45,7 @@ struct ggml_backend_rpc_context { struct ggml_backend_rpc_buffer_context { int sockfd; + std::unordered_map<ggml_backend_buffer_t, void *> base_cache; uint64_t remote_ptr; std::string name; }; @@ -62,6 +63,7 @@ static int socket_connect(const char * host, int port) { int flag = 1; int ret = setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); if (ret < 0) { + close(sock); return -1; } addr.sin_family = AF_INET; @@ -69,10 +71,12 @@ static int socket_connect(const char * host, int port) { struct hostent * server = gethostbyname(host); if (server == NULL) { fprintf(stderr, "Cannot resolve host '%s'\n", host); + close(sock); return -1; } bcopy((char *)server->h_addr, (char *)&addr.sin_addr.s_addr, server->h_length); if (connect(sock, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + close(sock); return -1; } return sock; @@ -152,11 +156,10 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t } GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { - static std::unordered_map<ggml_backend_buffer_t, void *> cache; - if (cache.find(buffer) != cache.end()) { - return cache[buffer]; - } ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; + if 
(ctx->base_cache.find(buffer) != ctx->base_cache.end()) { + return ctx->base_cache[buffer]; + } // input serialization format: | remote_ptr (8 bytes) | std::vector<uint8_t> input(sizeof(uint64_t), 0); uint64_t remote_ptr = ctx->remote_ptr; @@ -169,7 +172,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b uint64_t base_ptr; memcpy(&base_ptr, output.data(), sizeof(base_ptr)); void * base = reinterpret_cast<void *>(base_ptr); - cache[buffer] = base; + ctx->base_cache[buffer] = base; return base; } @@ -331,7 +334,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_rpc_buffer_interface, - new ggml_backend_rpc_buffer_context{buft_ctx->sockfd, remote_ptr, "RPC"}, + new ggml_backend_rpc_buffer_context{buft_ctx->sockfd, {}, remote_ptr, "RPC"}, remote_size); return buffer; @@ -343,6 +346,12 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_ return 128; } +GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { + UNUSED(buft); + // TODO: this is hardcoded for now but it should come from the remote backend + return SIZE_MAX; +} + GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { UNUSED(buft); return ggml_nbytes(tensor); @@ -361,7 +370,7 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { /* .get_name = */ ggml_backend_rpc_buffer_type_name, /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment, - /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_max_size = */ ggml_backend_rpc_get_max_size, /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size, /* .supports_backend = */ ggml_backend_rpc_buffer_type_supports_backend, /* .is_host = */ NULL, @@ -475,7 +484,7 @@ static 
std::unordered_map instances; GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const std::string & endpoint) { ggml_backend_t backend = ggml_backend_rpc_init(endpoint); - return ggml_backend_rpc_get_default_buffer_type(backend); + return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr; } GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) { @@ -488,7 +497,9 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) { std::string host = endpoint.substr(0, pos); int port = std::stoi(endpoint.substr(pos + 1)); int sockfd = socket_connect(host.c_str(), port); - GGML_ASSERT(sockfd >= 0 && "failed to connect to the server"); + if (sockfd < 0) { + return nullptr; + } ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { /* .sockfd = */ sockfd, @@ -502,7 +513,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) { ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { /* .endpoint = */ endpoint, - /* .name = */ "RPC", + /* .name = */ "RPC" + std::to_string(sockfd), /* .sockfd = */ sockfd, /* .buft = */ buft }; @@ -522,9 +533,9 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) { GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total) { UNUSED(endpoint); - UNUSED(total); // TODO: implement *free = 1; + *total = 1; } // RPC server-side implementation diff --git a/llama.cpp b/llama.cpp index cf60f5ad6..610cba5c6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15729,7 +15729,7 @@ struct llama_context * llama_new_context_with_model( for (auto & server : model->rpc_servers) { ggml_backend_t backend = ggml_backend_rpc_init(server); if (backend == nullptr) { - LLAMA_LOG_ERROR("%s: failed to initialize RPC backend, endpoint: %s\n", __func__, server.c_str()); + LLAMA_LOG_ERROR("%s: failed to connect RPC backend to %s\n", 
__func__, server.c_str()); llama_free(ctx); return nullptr; }