add get_device_memory
This commit is contained in:
parent
7a963c3087
commit
0b5e8a7183
3 changed files with 49 additions and 11 deletions
|
@ -26,9 +26,7 @@ static ggml_backend_t create_backend() {
|
|||
if (!backend) {
|
||||
fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_METAL
|
||||
#elif GGML_USE_METAL
|
||||
fprintf(stderr, "%s: using Metal backend\n", __func__);
|
||||
backend = ggml_backend_metal_init();
|
||||
if (!backend) {
|
||||
|
@ -44,6 +42,16 @@ static ggml_backend_t create_backend() {
|
|||
return backend;
|
||||
}
|
||||
|
||||
// Reports the free and total memory of the device served by this process.
//   free_mem  - out: amount of free memory
//   total_mem - out: total amount of memory
// When GGML_USE_CUDA is defined the query is forwarded to the CUDA backend for device 0.
// Otherwise both outputs are set to the placeholder value 1.
static void get_backend_memory(size_t * free_mem, size_t * total_mem) {
#if defined(GGML_USE_CUDA)
    ggml_backend_cuda_get_device_memory(0, free_mem, total_mem);
#else
    // TODO: implement for other backends
    *free_mem = *total_mem = 1;
#endif
}
|
||||
|
||||
static int create_server_socket(const char * host, int port) {
|
||||
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sockfd < 0) {
|
||||
|
@ -101,8 +109,10 @@ int main(int argc, char * argv[])
|
|||
close(client_socket);
|
||||
continue;
|
||||
}
|
||||
printf("Accepted client connection\n");
|
||||
rpc_serve_client(backend, client_socket);
|
||||
size_t free_mem, total_mem;
|
||||
get_backend_memory(&free_mem, &total_mem);
|
||||
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
||||
rpc_serve_client(backend, client_socket, free_mem, total_mem);
|
||||
printf("Client connection closed\n");
|
||||
close(client_socket);
|
||||
}
|
||||
|
|
37
ggml-rpc.cpp
37
ggml-rpc.cpp
|
@ -565,11 +565,31 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
|||
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
|
||||
}
|
||||
|
||||
static void get_device_memory(const std::shared_ptr<sockfd> & sock, size_t * free, size_t * total) {
|
||||
// input serialization format: | 0 bytes |
|
||||
std::vector<uint8_t> input;
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
||||
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
||||
uint64_t free_mem;
|
||||
memcpy(&free_mem, output.data(), sizeof(free_mem));
|
||||
uint64_t total_mem;
|
||||
memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
|
||||
*free = free_mem;
|
||||
*total = total_mem;
|
||||
}
|
||||
|
||||
// Queries the free and total device memory of the RPC server at `endpoint`.
//   endpoint - endpoint passed on to ggml_backend_rpc_init()
//   free     - out: free memory, or 0 when no backend could be obtained
//   total    - out: total memory, or 0 when no backend could be obtained
// The values are fetched with get_device_memory() over the socket stored in the
// backend's RPC context.
// NOTE(review): the backend returned by ggml_backend_rpc_init() is not released here;
// confirm who owns it.
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total) {
    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
    if (backend == nullptr) {
        // no backend for this endpoint: report no memory at all
        *free  = 0;
        *total = 0;
        return;
    }
    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
    get_device_memory(ctx->sock, free, total);
}
|
||||
|
||||
// RPC server-side implementation
|
||||
|
@ -759,7 +779,7 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
|
|||
ggml_free(ctx);
|
||||
}
|
||||
|
||||
void rpc_serve_client(ggml_backend_t backend, int sockfd) {
|
||||
void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem) {
|
||||
while (true) {
|
||||
uint8_t cmd;
|
||||
if (!recv_data(sockfd, &cmd, 1)) {
|
||||
|
@ -816,6 +836,13 @@ void rpc_serve_client(ggml_backend_t backend, int sockfd) {
|
|||
rpc_graph_compute(backend, input, output);
|
||||
break;
|
||||
}
|
||||
case GET_DEVICE_MEMORY: {
|
||||
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
||||
output.resize(2*sizeof(uint64_t), 0);
|
||||
memcpy(output.data(), &free_mem, sizeof(free_mem));
|
||||
memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
fprintf(stderr, "Unknown command: %d\n", cmd);
|
||||
break;
|
||||
|
|
|
@ -37,6 +37,7 @@ enum rpc_cmd {
|
|||
GET_TENSOR,
|
||||
COPY_TENSOR,
|
||||
GRAPH_COMPUTE,
|
||||
GET_DEVICE_MEMORY,
|
||||
};
|
||||
|
||||
#define GGML_RPC_MAX_SERVERS 16
|
||||
|
@ -49,7 +50,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
|
|||
|
||||
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total);
|
||||
|
||||
GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, int sockfd);
|
||||
GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue