implement get_alignment and get_max_size
This commit is contained in:
parent
ef9be32791
commit
7a963c3087
2 changed files with 65 additions and 8 deletions
71
ggml-rpc.cpp
71
ggml-rpc.cpp
|
@ -42,6 +42,8 @@ static ggml_guid_t ggml_backend_rpc_guid() {
|
|||
// Per-buffer-type state for the RPC backend. Instances are created in
// ggml_backend_rpc_init with a positional initializer, so the member order
// below must not change.
struct ggml_backend_rpc_buffer_type_context {
|
||||
// socket used to talk to the remote endpoint (shared ownership)
std::shared_ptr<sockfd> sock;
|
||||
// display name; built as "RPC" + the socket's fd when the backend is initialised
std::string name;
|
||||
// alignment obtained from the remote side via get_alignment(sock) at init time
size_t alignment;
|
||||
// maximum size obtained from the remote side via get_max_size(sock) at init time
size_t max_size;
|
||||
};
|
||||
|
||||
struct ggml_backend_rpc_context {
|
||||
|
@ -346,16 +348,40 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|||
return buffer;
|
||||
}
|
||||
|
||||
static size_t get_alignment(const std::shared_ptr<sockfd> & sock) {
|
||||
// input serialization format: | 0 bytes |
|
||||
std::vector<uint8_t> input;
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||
// output serialization format: | alignment (8 bytes) |
|
||||
uint64_t alignment;
|
||||
memcpy(&alignment, output.data(), sizeof(alignment));
|
||||
return alignment;
|
||||
}
|
||||
|
||||
// Returns the alignment recorded in the RPC buffer type's context.
// The value is stored there when the backend is initialised (it is fetched
// from the remote side), so no hardcoded constant is used here.
//   buft: buffer type whose `context` points to a ggml_backend_rpc_buffer_type_context
GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    return buft_ctx->alignment;
}
|
||||
|
||||
static size_t get_max_size(const std::shared_ptr<sockfd> & sock) {
|
||||
// input serialization format: | 0 bytes |
|
||||
std::vector<uint8_t> input;
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||
// output serialization format: | max_size (8 bytes) |
|
||||
uint64_t max_size;
|
||||
memcpy(&max_size, output.data(), sizeof(max_size));
|
||||
return max_size;
|
||||
}
|
||||
|
||||
// Returns the maximum size recorded in the RPC buffer type's context.
// The value is stored there when the backend is initialised (it is fetched
// from the remote side), so SIZE_MAX is no longer returned unconditionally.
//   buft: buffer type whose `context` points to a ggml_backend_rpc_buffer_type_context
GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    return buft_ctx->max_size;
}
|
||||
|
||||
GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||
|
@ -505,10 +531,13 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
|
|||
if (sock == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
size_t alignment = get_alignment(sock);
|
||||
size_t max_size = get_max_size(sock);
|
||||
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
||||
/* .sock = */ sock,
|
||||
/* .name = */ "RPC" + std::to_string(sock->fd)
|
||||
/* .name = */ "RPC" + std::to_string(sock->fd),
|
||||
/* .alignment = */ alignment,
|
||||
/* .max_size = */ max_size
|
||||
};
|
||||
|
||||
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
||||
|
@ -560,6 +589,24 @@ static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t>
|
|||
memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
|
||||
}
|
||||
|
||||
// Handler for the GET_ALIGNMENT command: looks up the alignment of the
// backend's default buffer type and serializes it into `output`.
//   backend: backend whose default buffer type is queried
//   output:  resized to 8 bytes and filled with the alignment
static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & output) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    size_t alignment = ggml_backend_buft_get_alignment(buft);
    // %zu is the matching conversion for size_t; %lu is wrong where unsigned long is narrower
    GGML_PRINT_DEBUG("[%s] alignment: %zu\n", __func__, alignment);
    // output serialization format: | alignment (8 bytes) |
    // widen first so all 8 wire bytes are written regardless of sizeof(size_t)
    const uint64_t wire_alignment = alignment;
    output.resize(sizeof(wire_alignment), 0);
    memcpy(output.data(), &wire_alignment, sizeof(wire_alignment));
}
|
||||
|
||||
// Handler for the GET_MAX_SIZE command: looks up the maximum size of the
// backend's default buffer type and serializes it into `output`.
//   backend: backend whose default buffer type is queried
//   output:  resized to 8 bytes and filled with the maximum size
static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & output) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    size_t max_size = ggml_backend_buft_get_max_size(buft);
    // %zu is the matching conversion for size_t; %lu is wrong where unsigned long is narrower
    GGML_PRINT_DEBUG("[%s] max_size: %zu\n", __func__, max_size);
    // output serialization format: | max_size (8 bytes) |
    // widen first so all 8 wire bytes are written regardless of sizeof(size_t)
    const uint64_t wire_max_size = max_size;
    output.resize(sizeof(wire_max_size), 0);
    memcpy(output.data(), &wire_max_size, sizeof(wire_max_size));
}
|
||||
|
||||
static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
|
||||
// input serialization format: | remote_ptr (8 bytes) |
|
||||
uint64_t remote_ptr;
|
||||
|
@ -733,6 +780,14 @@ void rpc_serve_client(ggml_backend_t backend, int sockfd) {
|
|||
rpc_alloc_buffer(backend, input, output);
|
||||
break;
|
||||
}
|
||||
case GET_ALIGNMENT: {
|
||||
rpc_get_alignment(backend, output);
|
||||
break;
|
||||
}
|
||||
case GET_MAX_SIZE: {
|
||||
rpc_get_max_size(backend, output);
|
||||
break;
|
||||
}
|
||||
case BUFFER_GET_BASE: {
|
||||
rpc_buffer_get_base(input, output);
|
||||
break;
|
||||
|
|
|
@ -28,6 +28,8 @@ struct rpc_tensor {
|
|||
// RPC commands
|
||||
enum rpc_cmd {
|
||||
ALLOC_BUFFER = 0,
|
||||
GET_ALIGNMENT,
|
||||
GET_MAX_SIZE,
|
||||
BUFFER_GET_BASE,
|
||||
FREE_BUFFER,
|
||||
BUFFER_CLEAR,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue