Address review comments

This commit is contained in:
Radoslav Gerganov 2024-04-30 15:27:54 +03:00
parent cddbf972c8
commit dfadd1a82c
2 changed files with 24 additions and 13 deletions

View file

@ -27,7 +27,7 @@
// RPC data structures
static ggml_guid_t ggml_backend_rpc_guid() {
static ggml_guid guid = { 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff};
static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
return &guid;
}
@ -45,6 +45,7 @@ struct ggml_backend_rpc_context {
struct ggml_backend_rpc_buffer_context {
int sockfd;
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
uint64_t remote_ptr;
std::string name;
};
@ -62,6 +63,7 @@ static int socket_connect(const char * host, int port) {
int flag = 1;
int ret = setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
if (ret < 0) {
close(sock);
return -1;
}
addr.sin_family = AF_INET;
@ -69,10 +71,12 @@ static int socket_connect(const char * host, int port) {
struct hostent * server = gethostbyname(host);
if (server == NULL) {
fprintf(stderr, "Cannot resolve host '%s'\n", host);
close(sock);
return -1;
}
bcopy((char *)server->h_addr, (char *)&addr.sin_addr.s_addr, server->h_length);
if (connect(sock, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
close(sock);
return -1;
}
return sock;
@ -152,11 +156,10 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
}
GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
static std::unordered_map<ggml_backend_buffer_t, void *> cache;
if (cache.find(buffer) != cache.end()) {
return cache[buffer];
}
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
return ctx->base_cache[buffer];
}
// input serialization format: | remote_ptr (8 bytes) |
std::vector<uint8_t> input(sizeof(uint64_t), 0);
uint64_t remote_ptr = ctx->remote_ptr;
@ -169,7 +172,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
uint64_t base_ptr;
memcpy(&base_ptr, output.data(), sizeof(base_ptr));
void * base = reinterpret_cast<void *>(base_ptr);
cache[buffer] = base;
ctx->base_cache[buffer] = base;
return base;
}
@ -331,7 +334,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface,
new ggml_backend_rpc_buffer_context{buft_ctx->sockfd, remote_ptr, "RPC"},
new ggml_backend_rpc_buffer_context{buft_ctx->sockfd, {}, remote_ptr, "RPC"},
remote_size);
return buffer;
@ -343,6 +346,12 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_
return 128;
}
GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
UNUSED(buft);
// TODO: this is hardcoded for now but it should come from the remote backend
return SIZE_MAX;
}
GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
UNUSED(buft);
return ggml_nbytes(tensor);
@ -361,7 +370,7 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
/* .get_name = */ ggml_backend_rpc_buffer_type_name,
/* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_max_size = */ ggml_backend_rpc_get_max_size,
/* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
/* .supports_backend = */ ggml_backend_rpc_buffer_type_supports_backend,
/* .is_host = */ NULL,
@ -475,7 +484,7 @@ static std::unordered_map<std::string, ggml_backend_t> instances;
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const std::string & endpoint) {
ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
return ggml_backend_rpc_get_default_buffer_type(backend);
return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
}
GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
@ -488,7 +497,9 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
std::string host = endpoint.substr(0, pos);
int port = std::stoi(endpoint.substr(pos + 1));
int sockfd = socket_connect(host.c_str(), port);
GGML_ASSERT(sockfd >= 0 && "failed to connect to the server");
if (sockfd < 0) {
return nullptr;
}
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
/* .sockfd = */ sockfd,
@ -502,7 +513,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
/* .endpoint = */ endpoint,
/* .name = */ "RPC",
/* .name = */ "RPC" + std::to_string(sockfd),
/* .sockfd = */ sockfd,
/* .buft = */ buft
};
@ -522,9 +533,9 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total) {
UNUSED(endpoint);
UNUSED(total);
// TODO: implement
*free = 1;
*total = 1;
}
// RPC server-side implementation

View file

@ -15729,7 +15729,7 @@ struct llama_context * llama_new_context_with_model(
for (auto & server : model->rpc_servers) {
ggml_backend_t backend = ggml_backend_rpc_init(server);
if (backend == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize RPC backend, endpoint: %s\n", __func__, server.c_str());
LLAMA_LOG_ERROR("%s: failed to connect RPC backend to %s\n", __func__, server.c_str());
llama_free(ctx);
return nullptr;
}