rpc: report real device memory through a new GET_DEVICE_MEMORY command

Summary of the change carried by this patch:
  * rpc-server: create_backend() selects CUDA or Metal with one #ifdef/#elif chain;
    a new get_backend_memory() helper is called once per accepted connection and
    its result is handed to rpc_serve_client().
  * ggml-rpc: the client-side ggml_backend_rpc_get_device_memory() now sends
    GET_DEVICE_MEMORY instead of returning the 1/1 placeholder; the server answers
    with the free/total values it was given when the connection was accepted.
  * GET_DEVICE_MEMORY is appended at the end of enum rpc_cmd, so the numeric values
    of the existing commands are unchanged.

NOTE(review): this patch was rebuilt from a copy whose line breaks, indentation and
template arguments had been stripped. Indentation of context lines is reconstructed
and should be checked against the target tree before applying. `std::vector<uint8_t>`
follows from the byte-wise memcpy usage; `std::shared_ptr<socket_t>` is an assumption
about the socket wrapper type - confirm against ggml-rpc.cpp. The index hashes are
those of the original patch.

diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp
index b2bbd7c2f..f809aec33 100644
--- a/examples/rpc/rpc-server.cpp
+++ b/examples/rpc/rpc-server.cpp
@@ -26,9 +26,7 @@ static ggml_backend_t create_backend() {
     if (!backend) {
         fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
     }
-#endif
-
-#ifdef GGML_USE_METAL
+#elif defined(GGML_USE_METAL)
     fprintf(stderr, "%s: using Metal backend\n", __func__);
     backend = ggml_backend_metal_init();
     if (!backend) {
@@ -44,6 +42,18 @@ static ggml_backend_t create_backend() {
     return backend;
 }
 
+// Writes the free and total memory of the local backend device to *free_mem / *total_mem.
+// Only CUDA (device 0) is queried; every other build reports the placeholder 1/1.
+static void get_backend_memory(size_t * free_mem, size_t * total_mem) {
+#ifdef GGML_USE_CUDA
+    ggml_backend_cuda_get_device_memory(0, free_mem, total_mem);
+#else
+    // TODO: implement for other backends
+    *free_mem = 1;
+    *total_mem = 1;
+#endif
+}
+
 static int create_server_socket(const char * host, int port) {
     int sockfd = socket(AF_INET, SOCK_STREAM, 0);
     if (sockfd < 0) {
@@ -101,8 +111,10 @@ int main(int argc, char * argv[])
             close(client_socket);
             continue;
         }
-        printf("Accepted client connection\n");
-        rpc_serve_client(backend, client_socket);
+        size_t free_mem, total_mem;
+        get_backend_memory(&free_mem, &total_mem);
+        printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
+        rpc_serve_client(backend, client_socket, free_mem, total_mem);
         printf("Client connection closed\n");
         close(client_socket);
     }
diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp
index 323d7690a..1c21c44a0 100644
--- a/ggml-rpc.cpp
+++ b/ggml-rpc.cpp
@@ -565,11 +565,32 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
     return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
 }
 
+// Sends GET_DEVICE_MEMORY over `sock` and stores the two 64-bit values of the reply in *free / *total.
+static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
+    // input serialization format: | 0 bytes |
+    std::vector<uint8_t> input;
+    std::vector<uint8_t> output;
+    bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
+    GGML_ASSERT(status);
+    GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
+    // output serialization format: | free (8 bytes) | total (8 bytes) |
+    uint64_t free_mem;
+    memcpy(&free_mem, output.data(), sizeof(free_mem));
+    uint64_t total_mem;
+    memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
+    *free = free_mem;
+    *total = total_mem;
+}
+
 GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total) {
-    UNUSED(endpoint);
-    // TODO: implement
-    *free = 1;
-    *total = 1;
+    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
+    if (backend == nullptr) {
+        *free = 0;
+        *total = 0;
+        return;
+    }
+    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
+    get_device_memory(ctx->sock, free, total);
 }
 
 // RPC server-side implementation
@@ -759,7 +780,7 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
     ggml_free(ctx);
 }
 
-void rpc_serve_client(ggml_backend_t backend, int sockfd) {
+void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem) {
     while (true) {
         uint8_t cmd;
         if (!recv_data(sockfd, &cmd, 1)) {
@@ -816,6 +837,14 @@ void rpc_serve_client(ggml_backend_t backend, int sockfd) {
                 rpc_graph_compute(backend, input, output);
                 break;
             }
+            case GET_DEVICE_MEMORY: {
+                // output serialization format: | free (8 bytes) | total (8 bytes) |
+                // widen to uint64_t so each field is exactly 8 bytes whatever sizeof(size_t) is
+                const uint64_t mem_info[2] = { free_mem, total_mem };
+                output.resize(sizeof(mem_info), 0);
+                memcpy(output.data(), mem_info, sizeof(mem_info));
+                break;
+            }
             default: {
                 fprintf(stderr, "Unknown command: %d\n", cmd);
                 break;
diff --git a/ggml-rpc.h b/ggml-rpc.h
index a98d789f5..a7dfaef71 100644
--- a/ggml-rpc.h
+++ b/ggml-rpc.h
@@ -37,6 +37,7 @@ enum rpc_cmd {
     GET_TENSOR,
     COPY_TENSOR,
     GRAPH_COMPUTE,
+    GET_DEVICE_MEMORY,
 };
 
 #define GGML_RPC_MAX_SERVERS 16
@@ -49,7 +50,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
 
 GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total);
 
-GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, int sockfd);
+GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem);
 
 #ifdef __cplusplus
 }