From 7a963c3087305b9f8c2773aa87ad16039ffc6519 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Tue, 7 May 2024 11:32:18 +0300 Subject: [PATCH] implement get_alignment and get_max_size --- ggml-rpc.cpp | 71 ++++++++++++++++++++++++++++++++++++++++++++++------ ggml-rpc.h | 2 ++ 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp index 3ff5700aa..323d7690a 100644 --- a/ggml-rpc.cpp +++ b/ggml-rpc.cpp @@ -42,6 +42,8 @@ static ggml_guid_t ggml_backend_rpc_guid() { struct ggml_backend_rpc_buffer_type_context { std::shared_ptr sock; std::string name; + size_t alignment; + size_t max_size; }; struct ggml_backend_rpc_context { @@ -346,16 +348,40 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer return buffer; } +static size_t get_alignment(const std::shared_ptr & sock) { + // input serialization format: | 0 bytes | + std::vector input; + std::vector output; + bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output); + GGML_ASSERT(status); + GGML_ASSERT(output.size() == sizeof(uint64_t)); + // output serialization format: | alignment (8 bytes) | + uint64_t alignment; + memcpy(&alignment, output.data(), sizeof(alignment)); + return alignment; +} + GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - UNUSED(buft); - // TODO: this is hardcoded for now but it should come from the remote backend - return 128; + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->alignment; +} + +static size_t get_max_size(const std::shared_ptr & sock) { + // input serialization format: | 0 bytes | + std::vector input; + std::vector output; + bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output); + GGML_ASSERT(status); + GGML_ASSERT(output.size() == sizeof(uint64_t)); + // output serialization format: | max_size (8 bytes) | + uint64_t max_size; + memcpy(&max_size, output.data(), sizeof(max_size)); + return max_size; } GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { - UNUSED(buft); - // TODO: this is hardcoded for now but it should come from the remote backend - return SIZE_MAX; + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; + return buft_ctx->max_size; } GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { @@ -505,10 +531,13 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) { if (sock == nullptr) { return nullptr; } - + size_t alignment = get_alignment(sock); + size_t max_size = get_max_size(sock); ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { /* .sock = */ sock, - /* .name = */ "RPC" + std::to_string(sock->fd) + /* .name = */ "RPC" + std::to_string(sock->fd), + /* .alignment = */ alignment, + /* .max_size = */ max_size }; ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { @@ -560,6 +589,24 @@ static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size)); } +static void rpc_get_alignment(ggml_backend_t backend, std::vector & output) { + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + size_t alignment = ggml_backend_buft_get_alignment(buft); + GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment); + // output serialization format: | alignment (8 bytes) | + output.resize(sizeof(uint64_t), 0); + memcpy(output.data(), &alignment, sizeof(alignment)); +} + +static void rpc_get_max_size(ggml_backend_t backend, std::vector & output) { + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + size_t max_size = ggml_backend_buft_get_max_size(buft); + GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size); + // output serialization format: | max_size (8 bytes) | + output.resize(sizeof(uint64_t), 0); + memcpy(output.data(), &max_size, sizeof(max_size)); +} + static void rpc_buffer_get_base(const std::vector & input, std::vector & output) { // input serialization format: | remote_ptr (8 bytes) | uint64_t remote_ptr; @@ -733,6 +780,14 @@ void rpc_serve_client(ggml_backend_t backend, int sockfd) { rpc_alloc_buffer(backend, input, output); break; } + case GET_ALIGNMENT: { + rpc_get_alignment(backend, output); + break; + } + case GET_MAX_SIZE: { + rpc_get_max_size(backend, output); + break; + } case BUFFER_GET_BASE: { rpc_buffer_get_base(input, output); break; diff --git a/ggml-rpc.h b/ggml-rpc.h index 6c1f6d091..a98d789f5 100644 --- a/ggml-rpc.h +++ b/ggml-rpc.h @@ -28,6 +28,8 @@ struct rpc_tensor { // RPC commands enum rpc_cmd { ALLOC_BUFFER = 0, + GET_ALIGNMENT, + GET_MAX_SIZE, BUFFER_GET_BASE, FREE_BUFFER, BUFFER_CLEAR,