Address review comments
This commit is contained in:
parent
cddbf972c8
commit
dfadd1a82c
2 changed files with 24 additions and 13 deletions
35
ggml-rpc.cpp
35
ggml-rpc.cpp
|
@ -27,7 +27,7 @@
|
|||
// RPC data structures
|
||||
|
||||
// Returns the address of a function-local static ggml_guid, so every call
// yields the same pointer.
// NOTE(review): this text is a rendered diff hunk; the two initializers below
// look like the removed and the added version of one line - confirm against
// the actual file before relying on either value.
static ggml_guid_t ggml_backend_rpc_guid() {
|
||||
static ggml_guid guid = { 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff};
|
||||
static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
|
||||
return &guid;
|
||||
}
|
||||
|
||||
|
@ -45,6 +45,7 @@ struct ggml_backend_rpc_context {
|
|||
|
||||
// Per-buffer state stored as the `context` of a buffer created by the RPC
// buffer type's alloc_buffer callback.
//   sockfd     - copied from the buffer type context when the buffer is
//                allocated (presumably the socket connected to the RPC
//                server - TODO confirm).
//   base_cache - maps a buffer handle to the base pointer that
//                ggml_backend_rpc_buffer_get_base stored for it.
//   remote_ptr - 64-bit value read by ggml_backend_rpc_buffer_get_base when
//                it builds its 8-byte request payload.
//   name       - initialized to "RPC" when the context is constructed.
// The field order matters: the context is built with positional aggregate
// initialization ({sockfd, {}, remote_ptr, "RPC"}).
struct ggml_backend_rpc_buffer_context {
|
||||
int sockfd;
|
||||
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
|
||||
uint64_t remote_ptr;
|
||||
std::string name;
|
||||
};
|
||||
|
@ -62,6 +63,7 @@ static int socket_connect(const char * host, int port) {
|
|||
int flag = 1;
|
||||
int ret = setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
|
||||
if (ret < 0) {
|
||||
close(sock);
|
||||
return -1;
|
||||
}
|
||||
addr.sin_family = AF_INET;
|
||||
|
@ -69,10 +71,12 @@ static int socket_connect(const char * host, int port) {
|
|||
struct hostent * server = gethostbyname(host);
|
||||
if (server == NULL) {
|
||||
fprintf(stderr, "Cannot resolve host '%s'\n", host);
|
||||
close(sock);
|
||||
return -1;
|
||||
}
|
||||
bcopy((char *)server->h_addr, (char *)&addr.sin_addr.s_addr, server->h_length);
|
||||
if (connect(sock, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
|
||||
close(sock);
|
||||
return -1;
|
||||
}
|
||||
return sock;
|
||||
|
@ -152,11 +156,10 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
|
|||
}
|
||||
|
||||
GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
static std::unordered_map<ggml_backend_buffer_t, void *> cache;
|
||||
if (cache.find(buffer) != cache.end()) {
|
||||
return cache[buffer];
|
||||
}
|
||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||
if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
|
||||
return ctx->base_cache[buffer];
|
||||
}
|
||||
// input serialization format: | remote_ptr (8 bytes) |
|
||||
std::vector<uint8_t> input(sizeof(uint64_t), 0);
|
||||
uint64_t remote_ptr = ctx->remote_ptr;
|
||||
|
@ -169,7 +172,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
|
|||
uint64_t base_ptr;
|
||||
memcpy(&base_ptr, output.data(), sizeof(base_ptr));
|
||||
void * base = reinterpret_cast<void *>(base_ptr);
|
||||
cache[buffer] = base;
|
||||
ctx->base_cache[buffer] = base;
|
||||
return base;
|
||||
}
|
||||
|
||||
|
@ -331,7 +334,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|||
|
||||
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
||||
ggml_backend_rpc_buffer_interface,
|
||||
new ggml_backend_rpc_buffer_context{buft_ctx->sockfd, remote_ptr, "RPC"},
|
||||
new ggml_backend_rpc_buffer_context{buft_ctx->sockfd, {}, remote_ptr, "RPC"},
|
||||
remote_size);
|
||||
|
||||
return buffer;
|
||||
|
@ -343,6 +346,12 @@ GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_
|
|||
return 128;
|
||||
}
|
||||
|
||||
// `get_max_size` callback registered in ggml_backend_rpc_buffer_type_interface.
// Ignores `buft` and always returns SIZE_MAX.
GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
|
||||
UNUSED(buft);
|
||||
// TODO: this is hardcoded for now but it should come from the remote backend
|
||||
return SIZE_MAX;
|
||||
}
|
||||
|
||||
GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
|
||||
UNUSED(buft);
|
||||
return ggml_nbytes(tensor);
|
||||
|
@ -361,7 +370,7 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
|||
/* .get_name = */ ggml_backend_rpc_buffer_type_name,
|
||||
/* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment,
|
||||
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
||||
/* .get_max_size = */ ggml_backend_rpc_get_max_size,
|
||||
/* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
|
||||
/* .supports_backend = */ ggml_backend_rpc_buffer_type_supports_backend,
|
||||
/* .is_host = */ NULL,
|
||||
|
@ -475,7 +484,7 @@ static std::unordered_map<std::string, ggml_backend_t> instances;
|
|||
|
||||
// Obtains a backend for `endpoint` via ggml_backend_rpc_init and returns that
// backend's default buffer type.
// NOTE(review): this text is a rendered diff hunk; the two `return` lines below
// look like the removed and the added version of one statement. The second one
// returns nullptr when ggml_backend_rpc_init yields nullptr; the first passes
// the backend on unchecked - confirm which one the actual file contains.
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const std::string & endpoint) {
|
||||
ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
|
||||
return ggml_backend_rpc_get_default_buffer_type(backend);
|
||||
return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
|
||||
}
|
||||
|
||||
GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
|
||||
|
@ -488,7 +497,9 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
|
|||
std::string host = endpoint.substr(0, pos);
|
||||
int port = std::stoi(endpoint.substr(pos + 1));
|
||||
int sockfd = socket_connect(host.c_str(), port);
|
||||
GGML_ASSERT(sockfd >= 0 && "failed to connect to the server");
|
||||
if (sockfd < 0) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
||||
/* .sockfd = */ sockfd,
|
||||
|
@ -502,7 +513,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const std::string & endpoint) {
|
|||
|
||||
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
||||
/* .endpoint = */ endpoint,
|
||||
/* .name = */ "RPC",
|
||||
/* .name = */ "RPC" + std::to_string(sockfd),
|
||||
/* .sockfd = */ sockfd,
|
||||
/* .buft = */ buft
|
||||
};
|
||||
|
@ -522,9 +533,9 @@ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
|||
|
||||
// Placeholder: `endpoint` is not used and no remote query is made; the
// out-parameters are written with the constant 1.
// `free` (and `total`, where written) are dereferenced without a null check.
// NOTE(review): this text is a rendered diff hunk; `UNUSED(total);` and
// `*total = 1;` look like the removed and the added line respectively -
// confirm against the actual file.
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const std::string & endpoint, size_t * free, size_t * total) {
|
||||
UNUSED(endpoint);
|
||||
UNUSED(total);
|
||||
// TODO: implement
|
||||
*free = 1;
|
||||
*total = 1;
|
||||
}
|
||||
|
||||
// RPC server-side implementation
|
||||
|
|
|
@ -15729,7 +15729,7 @@ struct llama_context * llama_new_context_with_model(
|
|||
for (auto & server : model->rpc_servers) {
|
||||
ggml_backend_t backend = ggml_backend_rpc_init(server);
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize RPC backend, endpoint: %s\n", __func__, server.c_str());
|
||||
LLAMA_LOG_ERROR("%s: failed to connect RPC backend to %s\n", __func__, server.c_str());
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue