add get_device_memory

Radoslav Gerganov 2024-05-07 14:05:33 +03:00
parent 7a963c3087
commit 0b5e8a7183
3 changed files with 49 additions and 11 deletions


@@ -26,9 +26,7 @@ static ggml_backend_t create_backend() {
if (!backend) {
fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
}
#endif
#ifdef GGML_USE_METAL
#elif GGML_USE_METAL
fprintf(stderr, "%s: using Metal backend\n", __func__);
backend = ggml_backend_metal_init();
if (!backend) {
@@ -44,6 +42,16 @@ static ggml_backend_t create_backend() {
return backend;
}
static void get_backend_memory(size_t * free_mem, size_t * total_mem) {
#ifdef GGML_USE_CUDA
ggml_backend_cuda_get_device_memory(0, free_mem, total_mem);
#else
// TODO: implement for other backends
*free_mem = 1;
*total_mem = 1;
#endif
}
static int create_server_socket(const char * host, int port) {
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
if (sockfd < 0) {
@@ -101,8 +109,10 @@ int main(int argc, char * argv[])
close(client_socket);
continue;
}
printf("Accepted client connection\n");
rpc_serve_client(backend, client_socket);
size_t free_mem, total_mem;
get_backend_memory(&free_mem, &total_mem);
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
rpc_serve_client(backend, client_socket, free_mem, total_mem);
printf("Client connection closed\n");
close(client_socket);
}
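
The non-CUDA branch of get_backend_memory above is left as a TODO and reports placeholder values of 1 byte. Below is a possible sketch of that fallback for Linux/glibc hosts, reporting host RAM instead; it is not part of this commit, and both the get_host_memory name and the sysconf-based approach are illustrative assumptions:

#include <unistd.h>

// Hypothetical helper (not in this commit): report host RAM as a stand-in
// when no GPU backend is compiled in.
static void get_host_memory(size_t * free_mem, size_t * total_mem) {
    long page_size   = sysconf(_SC_PAGESIZE);
    long avail_pages = sysconf(_SC_AVPHYS_PAGES); // glibc extension
    long total_pages = sysconf(_SC_PHYS_PAGES);   // glibc extension
    if (page_size > 0 && avail_pages > 0 && total_pages > 0) {
        *free_mem  = (size_t) avail_pages * (size_t) page_size;
        *total_mem = (size_t) total_pages * (size_t) page_size;
    } else {
        // fall back to the placeholder values used in the commit
        *free_mem  = 1;
        *total_mem = 1;
    }
}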


@@ -565,11 +565,31 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
}
static void get_device_memory(const std::shared_ptr<sockfd> & sock, size_t * free, size_t * total) {
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | free (8 bytes) | total (8 bytes) |
uint64_t free_mem;
memcpy(&free_mem, output.data(), sizeof(free_mem));
uint64_t total_mem;
memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
*free = free_mem;
*total = total_mem;
}
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total) {
UNUSED(endpoint);
// TODO: implement
*free = 1;
*total = 1;
ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
if (backend == nullptr) {
*free = 0;
*total = 0;
return;
}
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
get_device_memory(ctx->sock, free, total);
}
// RPC server-side implementation
@@ -759,7 +779,7 @@ static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t>
ggml_free(ctx);
}
void rpc_serve_client(ggml_backend_t backend, int sockfd) {
void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem) {
while (true) {
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
@@ -816,6 +836,13 @@ void rpc_serve_client(ggml_backend_t backend, int sockfd) {
rpc_graph_compute(backend, input, output);
break;
}
case GET_DEVICE_MEMORY: {
// output serialization format: | free (8 bytes) | total (8 bytes) |
output.resize(2*sizeof(uint64_t), 0);
memcpy(output.data(), &free_mem, sizeof(free_mem));
memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
break;
}
default: {
fprintf(stderr, "Unknown command: %d\n", cmd);
break;
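
For illustration of the GET_DEVICE_MEMORY wire format handled above: both sides memcpy the two counters directly, so the 16-byte reply is two uint64_t values in host byte order (no byte swapping is performed, so client and server are implicitly assumed to share endianness). With arbitrary example values of free = 8 GiB and total = 24 GiB on a little-endian machine, the payload would be:

bytes 0..7  : 00 00 00 00 02 00 00 00   // free  = 0x200000000 = 8 GiB
bytes 8..15 : 00 00 00 00 06 00 00 00   // total = 0x600000000 = 24 GiB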


@@ -37,6 +37,7 @@ enum rpc_cmd {
GET_TENSOR,
COPY_TENSOR,
GRAPH_COMPUTE,
GET_DEVICE_MEMORY,
};
#define GGML_RPC_MAX_SERVERS 16
@@ -49,7 +50,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total);
GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, int sockfd);
GGML_API GGML_CALL void rpc_serve_client(ggml_backend_t backend, int sockfd, size_t free_mem, size_t total_mem);
#ifdef __cplusplus
}
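
A minimal sketch of how a client could call the new public entry point, assuming the declarations above are exposed via a ggml-rpc.h header and an rpc-server instance is already listening on localhost:50052 (both the header name and the endpoint are assumptions, not taken from this commit):

#include "ggml-rpc.h"
#include <cstdio>
#include <string>

int main() {
    std::string endpoint = "localhost:50052"; // hypothetical "host:port" endpoint
    size_t free_mem = 0, total_mem = 0;
    // internally opens an RPC backend connection and issues GET_DEVICE_MEMORY;
    // per the implementation above, both values are set to 0 if the connection fails
    ggml_backend_rpc_get_device_memory(endpoint, &free_mem, &total_mem);
    printf("%s: free = %zu bytes, total = %zu bytes\n", endpoint.c_str(), free_mem, total_mem);
    return 0;
}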