implement get_alignment and get_max_size

This commit is contained in:
Radoslav Gerganov 2024-05-07 11:32:18 +03:00
parent ef9be32791
commit 7a963c3087
2 changed files with 65 additions and 8 deletions

View file

@ -42,6 +42,8 @@ static ggml_guid_t ggml_backend_rpc_guid() {
struct ggml_backend_rpc_buffer_type_context { struct ggml_backend_rpc_buffer_type_context {
std::shared_ptr<sockfd> sock; std::shared_ptr<sockfd> sock;
std::string name; std::string name;
size_t alignment;
size_t max_size;
}; };
struct ggml_backend_rpc_context { struct ggml_backend_rpc_context {
@ -346,16 +348,40 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
return buffer; return buffer;
} }
// Ask the remote endpoint behind `sock` for its buffer alignment.
//
// Issues a GET_ALIGNMENT command with an empty payload and decodes the
// 8-byte reply. Both a failed send_rpc_cmd call and a reply of unexpected
// length trip a GGML_ASSERT.
static size_t get_alignment(const std::shared_ptr<sockfd> & sock) {
    uint64_t remote_alignment = 0;

    // request payload: | 0 bytes |
    std::vector<uint8_t> request;
    std::vector<uint8_t> output;
    const bool status = send_rpc_cmd(sock, GET_ALIGNMENT, request, output);
    GGML_ASSERT(status);

    // reply payload: | alignment (8 bytes) |
    GGML_ASSERT(output.size() == sizeof(uint64_t));
    memcpy(&remote_alignment, output.data(), sizeof(remote_alignment));

    return remote_alignment;
}
GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
UNUSED(buft); ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
// TODO: this is hardcoded for now but it should come from the remote backend return buft_ctx->alignment;
return 128; }
// Query the remote endpoint behind `sock` for its maximum buffer size.
//
// Sends a GET_MAX_SIZE command with an empty payload through send_rpc_cmd
// and returns the 8-byte value found in the reply. A failed command or a
// reply whose length is not exactly sizeof(uint64_t) triggers GGML_ASSERT.
static size_t get_max_size(const std::shared_ptr<sockfd> & sock) {
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
GGML_ASSERT(status);
// the reply must carry exactly one 64-bit value
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | max_size (8 bytes) |
uint64_t max_size;
// memcpy rather than a pointer cast: copies the raw bytes out of the reply buffer
memcpy(&max_size, output.data(), sizeof(max_size));
// the uint64_t value is implicitly converted to the size_t return type
return max_size;
} }
GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) { GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
UNUSED(buft); ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
// TODO: this is hardcoded for now but it should come from the remote backend return buft_ctx->max_size;
return SIZE_MAX;
} }
GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
@ -505,10 +531,13 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
if (sock == nullptr) { if (sock == nullptr) {
return nullptr; return nullptr;
} }
size_t alignment = get_alignment(sock);
size_t max_size = get_max_size(sock);
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
/* .sock = */ sock, /* .sock = */ sock,
/* .name = */ "RPC" + std::to_string(sock->fd) /* .name = */ "RPC" + std::to_string(sock->fd),
/* .alignment = */ alignment,
/* .max_size = */ max_size
}; };
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
@ -560,6 +589,24 @@ static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t>
memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size)); memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
} }
// Server-side handler for the GET_ALIGNMENT command.
//
// Looks up the alignment of the backend's default buffer type and writes it
// into `output` as a single 8-byte value.
//
//   backend - backend whose default buffer type is queried
//   output  - receives the serialized reply; resized to 8 bytes
static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & output) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    const size_t alignment = ggml_backend_buft_get_alignment(buft);
    // %zu is the portable conversion for size_t (%lu mismatches where
    // unsigned long and size_t differ in width)
    GGML_PRINT_DEBUG("[%s] alignment: %zu\n", __func__, alignment);
    // output serialization format: | alignment (8 bytes) |
    // Widen to uint64_t first so exactly 8 bytes are always written,
    // independent of sizeof(size_t) on the serving host.
    const uint64_t wire_alignment = alignment;
    output.resize(sizeof(wire_alignment), 0);
    memcpy(output.data(), &wire_alignment, sizeof(wire_alignment));
}
// Server-side handler for the GET_MAX_SIZE command.
//
// Looks up the maximum size reported by the backend's default buffer type
// and writes it into `output` as a single 8-byte value.
//
//   backend - backend whose default buffer type is queried
//   output  - receives the serialized reply; resized to 8 bytes
static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & output) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    const size_t max_size = ggml_backend_buft_get_max_size(buft);
    // %zu is the portable conversion for size_t (%lu mismatches where
    // unsigned long and size_t differ in width)
    GGML_PRINT_DEBUG("[%s] max_size: %zu\n", __func__, max_size);
    // output serialization format: | max_size (8 bytes) |
    // Widen to uint64_t first so exactly 8 bytes are always written,
    // independent of sizeof(size_t) on the serving host.
    const uint64_t wire_max_size = max_size;
    output.resize(sizeof(wire_max_size), 0);
    memcpy(output.data(), &wire_max_size, sizeof(wire_max_size));
}
static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) { static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
// input serialization format: | remote_ptr (8 bytes) | // input serialization format: | remote_ptr (8 bytes) |
uint64_t remote_ptr; uint64_t remote_ptr;
@ -733,6 +780,14 @@ void rpc_serve_client(ggml_backend_t backend, int sockfd) {
rpc_alloc_buffer(backend, input, output); rpc_alloc_buffer(backend, input, output);
break; break;
} }
case GET_ALIGNMENT: {
rpc_get_alignment(backend, output);
break;
}
case GET_MAX_SIZE: {
rpc_get_max_size(backend, output);
break;
}
case BUFFER_GET_BASE: { case BUFFER_GET_BASE: {
rpc_buffer_get_base(input, output); rpc_buffer_get_base(input, output);
break; break;

View file

@ -28,6 +28,8 @@ struct rpc_tensor {
// RPC commands // RPC commands
enum rpc_cmd { enum rpc_cmd {
ALLOC_BUFFER = 0, ALLOC_BUFFER = 0,
GET_ALIGNMENT,
GET_MAX_SIZE,
BUFFER_GET_BASE, BUFFER_GET_BASE,
FREE_BUFFER, FREE_BUFFER,
BUFFER_CLEAR, BUFFER_CLEAR,