implement get_alignment and get_max_size
This commit is contained in:
parent
ef9be32791
commit
7a963c3087
2 changed files with 65 additions and 8 deletions
71
ggml-rpc.cpp
71
ggml-rpc.cpp
|
@ -42,6 +42,8 @@ static ggml_guid_t ggml_backend_rpc_guid() {
|
||||||
struct ggml_backend_rpc_buffer_type_context {
|
struct ggml_backend_rpc_buffer_type_context {
|
||||||
std::shared_ptr<sockfd> sock;
|
std::shared_ptr<sockfd> sock;
|
||||||
std::string name;
|
std::string name;
|
||||||
|
size_t alignment;
|
||||||
|
size_t max_size;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_backend_rpc_context {
|
struct ggml_backend_rpc_context {
|
||||||
|
@ -346,16 +348,40 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static size_t get_alignment(const std::shared_ptr<sockfd> & sock) {
|
||||||
|
// input serialization format: | 0 bytes |
|
||||||
|
std::vector<uint8_t> input;
|
||||||
|
std::vector<uint8_t> output;
|
||||||
|
bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
|
||||||
|
GGML_ASSERT(status);
|
||||||
|
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||||
|
// output serialization format: | alignment (8 bytes) |
|
||||||
|
uint64_t alignment;
|
||||||
|
memcpy(&alignment, output.data(), sizeof(alignment));
|
||||||
|
return alignment;
|
||||||
|
}
|
||||||
|
|
||||||
// Buffer-type callback: returns the alignment stored in the buffer type's
// RPC context (filled in by ggml_backend_rpc_init from the server's answer).
GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    const ggml_backend_rpc_buffer_type_context * rpc_ctx =
        (const ggml_backend_rpc_buffer_type_context *) buft->context;
    return rpc_ctx->alignment;
}
|
||||||
|
|
||||||
|
static size_t get_max_size(const std::shared_ptr<sockfd> & sock) {
|
||||||
|
// input serialization format: | 0 bytes |
|
||||||
|
std::vector<uint8_t> input;
|
||||||
|
std::vector<uint8_t> output;
|
||||||
|
bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
|
||||||
|
GGML_ASSERT(status);
|
||||||
|
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||||
|
// output serialization format: | max_size (8 bytes) |
|
||||||
|
uint64_t max_size;
|
||||||
|
memcpy(&max_size, output.data(), sizeof(max_size));
|
||||||
|
return max_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Buffer-type callback: returns the maximum buffer size stored in the buffer
// type's RPC context (filled in by ggml_backend_rpc_init from the server's answer).
GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
    const ggml_backend_rpc_buffer_type_context * rpc_ctx =
        (const ggml_backend_rpc_buffer_type_context *) buft->context;
    return rpc_ctx->max_size;
}
|
||||||
|
|
||||||
GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||||
|
@ -505,10 +531,13 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
|
||||||
if (sock == nullptr) {
|
if (sock == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
size_t alignment = get_alignment(sock);
|
||||||
|
size_t max_size = get_max_size(sock);
|
||||||
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
||||||
/* .sock = */ sock,
|
/* .sock = */ sock,
|
||||||
/* .name = */ "RPC" + std::to_string(sock->fd)
|
/* .name = */ "RPC" + std::to_string(sock->fd),
|
||||||
|
/* .alignment = */ alignment,
|
||||||
|
/* .max_size = */ max_size
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
||||||
|
@ -560,6 +589,24 @@ static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t>
|
||||||
memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
|
memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Server-side handler for GET_ALIGNMENT.
// Queries the alignment of the backend's default buffer type and writes it
// into `output` as a single 8-byte value.
//
//   backend - backend whose default buffer type is inspected
//   output  - reply buffer; resized to exactly sizeof(uint64_t) bytes
static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & output) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    const size_t alignment = ggml_backend_buft_get_alignment(buft);
    // %zu is the portable conversion for size_t (%lu is wrong where unsigned long != size_t)
    GGML_PRINT_DEBUG("[%s] alignment: %zu\n", __func__, alignment);
    // output serialization format: | alignment (8 bytes) |
    // Widen explicitly so all 8 wire bytes come from one uint64_t, independent
    // of the host's sizeof(size_t).
    const uint64_t wire_alignment = alignment;
    output.resize(sizeof(wire_alignment), 0);
    memcpy(output.data(), &wire_alignment, sizeof(wire_alignment));
}
|
||||||
|
|
||||||
|
// Server-side handler for GET_MAX_SIZE.
// Queries the maximum buffer size of the backend's default buffer type and
// writes it into `output` as a single 8-byte value.
//
//   backend - backend whose default buffer type is inspected
//   output  - reply buffer; resized to exactly sizeof(uint64_t) bytes
static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & output) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    const size_t max_size = ggml_backend_buft_get_max_size(buft);
    // %zu is the portable conversion for size_t (%lu is wrong where unsigned long != size_t)
    GGML_PRINT_DEBUG("[%s] max_size: %zu\n", __func__, max_size);
    // output serialization format: | max_size (8 bytes) |
    // Widen explicitly so all 8 wire bytes come from one uint64_t, independent
    // of the host's sizeof(size_t).
    const uint64_t wire_max_size = max_size;
    output.resize(sizeof(wire_max_size), 0);
    memcpy(output.data(), &wire_max_size, sizeof(wire_max_size));
}
|
||||||
|
|
||||||
static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
|
static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
|
||||||
// input serialization format: | remote_ptr (8 bytes) |
|
// input serialization format: | remote_ptr (8 bytes) |
|
||||||
uint64_t remote_ptr;
|
uint64_t remote_ptr;
|
||||||
|
@ -733,6 +780,14 @@ void rpc_serve_client(ggml_backend_t backend, int sockfd) {
|
||||||
rpc_alloc_buffer(backend, input, output);
|
rpc_alloc_buffer(backend, input, output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case GET_ALIGNMENT: {
|
||||||
|
rpc_get_alignment(backend, output);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case GET_MAX_SIZE: {
|
||||||
|
rpc_get_max_size(backend, output);
|
||||||
|
break;
|
||||||
|
}
|
||||||
case BUFFER_GET_BASE: {
|
case BUFFER_GET_BASE: {
|
||||||
rpc_buffer_get_base(input, output);
|
rpc_buffer_get_base(input, output);
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -28,6 +28,8 @@ struct rpc_tensor {
|
||||||
// RPC commands
|
// RPC commands
|
||||||
enum rpc_cmd {
|
enum rpc_cmd {
|
||||||
ALLOC_BUFFER = 0,
|
ALLOC_BUFFER = 0,
|
||||||
|
GET_ALIGNMENT,
|
||||||
|
GET_MAX_SIZE,
|
||||||
BUFFER_GET_BASE,
|
BUFFER_GET_BASE,
|
||||||
FREE_BUFFER,
|
FREE_BUFFER,
|
||||||
BUFFER_CLEAR,
|
BUFFER_CLEAR,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue