From 4af9c8742cc2a657127d27723faee77a4b8f6f25 Mon Sep 17 00:00:00 2001
From: hongruichen
Date: Sat, 18 May 2024 11:09:45 +0800
Subject: [PATCH] free buffer after client disconnect

---
 ggml-rpc.cpp | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp
index 4a9bfa52d..03d92d3b4 100644
--- a/ggml-rpc.cpp
+++ b/ggml-rpc.cpp
@@ -8,6 +8,7 @@
 #include <memory>
 #include <unordered_map>
 #include <unordered_set>
+#include <list>
 #ifdef _WIN32
 #  define WIN32_LEAN_AND_MEAN
 #  ifndef NOMINMAX
@@ -731,7 +732,7 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint
 
 // RPC server-side implementation
 
-static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+static ggml_backend_buffer_t rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
     // input serialization format: | size (8 bytes) |
     uint64_t size;
     memcpy(&size, input.data(), sizeof(size));
@@ -744,6 +745,7 @@ static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t>
     output.resize(2*sizeof(uint64_t), 0);
     memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
     memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
+    return buffer;
 }
 
 static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & output) {
@@ -777,13 +779,14 @@ static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<
     memcpy(output.data(), &base_ptr, sizeof(base_ptr));
 }
 
-static void rpc_free_buffer(const std::vector<uint8_t> & input) {
+static ggml_backend_buffer_t rpc_free_buffer(const std::vector<uint8_t> & input) {
     // input serialization format: | remote_ptr (8 bytes) |
     uint64_t remote_ptr;
     memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
     ggml_backend_buffer_free(buffer);
+    return buffer;
 }
 
 static void rpc_buffer_clear(const std::vector<uint8_t> & input) {
@@ -917,6 +920,7 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
 }
 
 static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
+    std::list<ggml_backend_buffer_t> allocated_buffers;
     while (true) {
         uint8_t cmd;
         if (!recv_data(sockfd, &cmd, 1)) {
@@ -934,7 +938,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
         }
         switch (cmd) {
             case ALLOC_BUFFER: {
-                rpc_alloc_buffer(backend, input, output);
+                allocated_buffers.push_back(rpc_alloc_buffer(backend, input, output));
                 break;
             }
             case GET_ALIGNMENT: {
@@ -950,7 +954,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
                 break;
             }
             case FREE_BUFFER: {
-                rpc_free_buffer(input);
+                allocated_buffers.remove(rpc_free_buffer(input));
                 break;
             }
             case BUFFER_CLEAR: {
@@ -993,6 +997,10 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
             break;
         }
     }
+
+    for (auto buff: allocated_buffers) {
+        ggml_backend_buffer_free(buff);
+    }
 }
 
 void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {