free buffer after client disconnect

This commit is contained in:
hongruichen 2024-05-18 11:09:45 +08:00 committed by Hongrui Chen
parent 511182eabb
commit 4af9c8742c

View file

@ -8,6 +8,7 @@
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <list>
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
@ -731,7 +732,7 @@ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint
// RPC server-side implementation
static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
static ggml_backend_buffer_t rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
// input serialization format: | size (8 bytes) |
uint64_t size;
memcpy(&size, input.data(), sizeof(size));
@ -744,6 +745,7 @@ static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t>
output.resize(2*sizeof(uint64_t), 0);
memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
return buffer;
}
static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & output) {
@ -777,13 +779,14 @@ static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<
memcpy(output.data(), &base_ptr, sizeof(base_ptr));
}
static void rpc_free_buffer(const std::vector<uint8_t> & input) {
static ggml_backend_buffer_t rpc_free_buffer(const std::vector<uint8_t> & input) {
// input serialization format: | remote_ptr (8 bytes) |
uint64_t remote_ptr;
memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
ggml_backend_buffer_free(buffer);
return buffer;
}
static void rpc_buffer_clear(const std::vector<uint8_t> & input) {
@ -917,6 +920,7 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
}
static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
std::list<ggml_backend_buffer_t> allocated_buffers;
while (true) {
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
@ -934,7 +938,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
}
switch (cmd) {
case ALLOC_BUFFER: {
rpc_alloc_buffer(backend, input, output);
allocated_buffers.push_back(rpc_alloc_buffer(backend, input, output));
break;
}
case GET_ALIGNMENT: {
@ -950,7 +954,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
break;
}
case FREE_BUFFER: {
rpc_free_buffer(input);
allocated_buffers.remove(rpc_free_buffer(input));
break;
}
case BUFFER_CLEAR: {
@ -993,6 +997,10 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
break;
}
}
for (auto buff: allocated_buffers) {
ggml_backend_buffer_free(buff);
}
}
void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {