Cleaned up and improved type/error handling.
parent c47dc70b58
commit 1948ae8491
1 changed file with 21 additions and 34 deletions
@@ -110,9 +110,9 @@ struct rpc_msg_init_tensor_req {
    rpc_tensor tensor;
};

struct rpc_msg_init_tensor_rsp {
    uint8_t result; // success/failure
};
//struct rpc_msg_init_tensor_rsp {
//    uint8_t result; // success/failure
//};

struct rpc_msg_alloc_buffer_req {
    uint64_t size;
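Aside (not part of this commit): these request/response messages are fixed-size PODs that the RPC layer copies to and from the socket as raw bytes, as the send_rpc_cmd call later in this diff suggests. A minimal, self-contained sketch of a compile-time guard one could use to keep such a struct wire-safe; the guard is an illustration, only the struct fields come from the hunk above.

#include <cstdint>
#include <cstdio>
#include <type_traits>

// Response message as defined in the hunk above: a single success/failure byte.
struct rpc_msg_init_tensor_rsp {
    uint8_t result; // success/failure
};

// Illustrative guard, not something this commit adds: structs that cross the
// wire as raw bytes must stay trivially copyable on both client and server.
static_assert(std::is_trivially_copyable<rpc_msg_init_tensor_rsp>::value,
              "RPC message structs are transferred as raw bytes");

int main() {
    printf("rpc_msg_init_tensor_rsp wire size: %zu byte(s)\n",
           sizeof(rpc_msg_init_tensor_rsp));
    return 0;
}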
@@ -479,16 +479,15 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
}

static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    //UNUSED(buffer);
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;

    // CUDA backend on the server pads everything to 512 due to CUDA limitations.
    // Due to bandwidth constraints, we only call the server init tensor functions if necessary.
    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0)) {
        // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
        //GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
        rpc_msg_init_tensor_req request;

        request.tensor = serialize_tensor(tensor);

        //rpc_msg_init_tensor_rsp response;
        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
        GGML_ASSERT(status);
    }
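Aside (not part of this commit): the gating condition above is the key point, since the extra server round trip is only paid for quantized tensors whose first dimension is not already a multiple of 512. A small self-contained sketch of that decision; the helper name and the simplified "is_quantized" flag are mine, only the predicate and the 512 constant come from the diff.

#include <cstdint>
#include <cstdio>

// Stand-in for the check in ggml_backend_rpc_buffer_init_tensor above.
static bool needs_server_init(bool is_quantized, int64_t ne0) {
    return is_quantized && (ne0 % 512 != 0);
}

int main() {
    // Quantized, ne[0] == 1000: 1000 % 512 == 488, so the client must ask the
    // server to initialize (and possibly pad) the tensor.
    printf("quantized, ne0=1000 -> %s\n", needs_server_init(true, 1000) ? "RPC init" : "skip");
    // Quantized, ne[0] == 1024 is already a multiple of 512: no round trip.
    printf("quantized, ne0=1024 -> %s\n", needs_server_init(true, 1024) ? "RPC init" : "skip");
    // Non-quantized tensors never trigger the extra call.
    printf("f32,       ne0=1000 -> %s\n", needs_server_init(false, 1000) ? "RPC init" : "skip");
    return 0;
}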
@@ -603,11 +602,13 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
}

static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    // See comments in init_tensor.
    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0)) {
        ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
        auto sock = get_socket(buft_ctx->endpoint);

        rpc_msg_get_alloc_size_req request;

        request.tensor = serialize_tensor(tensor);

        rpc_msg_get_alloc_size_rsp response;
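Aside (not part of this commit): the reason the client asks the server for the allocation size is that only the server backend knows its own row padding, so the padded size can exceed the raw data size reported locally. A rough, self-contained arithmetic sketch; the 32-element / 18-byte block figures correspond to Q4_0 and the 512-element padding to the CUDA comment in this diff, and all of them are assumptions made for the sake of the example.

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne0         = 1056;   // elements in one row (divisible by the block size)
    const int64_t block_elems = 32;     // assumed Q4_0 block size in elements
    const int64_t block_bytes = 18;     // assumed bytes per Q4_0 block
    const int64_t pad         = 512;    // assumed server-side row padding in elements

    const int64_t raw_bytes    = ne0 / block_elems * block_bytes;              // size the client would compute
    const int64_t padded_ne0   = (ne0 + pad - 1) / pad * pad;                  // round up to a multiple of 512
    const int64_t padded_bytes = padded_ne0 / block_elems * block_bytes;       // size the padding backend allocates

    printf("raw row size    : %lld bytes\n", (long long) raw_bytes);     // 594
    printf("padded row size : %lld bytes\n", (long long) padded_bytes);  // 864
    return 0;
}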
@@ -812,41 +813,30 @@ private:
};

bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    ggml_backend_buffer_type_t buft;
    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };

    struct ggml_context * ctx = ggml_init(params);
    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);

    if (tensor == nullptr) {
        printf("Got nullptr\n");
        fprintf(stderr,"Null tensor pointer passed to server get_alloc_size function.\n");
        ggml_free(ctx);
        return false;
    }

    printf("Getting buft\n");

    //ggml_backend_buffer_get_alloc_size(tensor->buffer,tensor)

    //if (tensor->buffer == nullptr) {
    //    printf("Got null buffer\n");
    //    response.alloc_size = 0;
    //    ggml_free(ctx);
    //    return true;
    //}
    if (tensor->buffer == nullptr) {
        //No buffer allocated.
        buft = ggml_backend_get_default_buffer_type(backend);
    } else {
        buft = tensor->buffer->buft;
    }

    response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
    // Call the backend's buffer_type_get_alloc_size function
    //ggml_backend_buffer_type_t buft = tensor->buffer->buft;
    //if (buft && buft->iface.get_alloc_size) {
    //    printf("Called buffer type get alloc size\n");
    //    response.alloc_size = buft->iface.get_alloc_size(buft, tensor);
    //} else {
    //    printf("Called ggml_nbytes");
    //    response.alloc_size = ggml_nbytes(tensor);
    //}

    ggml_free(ctx);
    return true;
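Aside (not part of this commit): the commented-out block above hand-rolls the same dispatch that ggml_backend_buft_get_alloc_size already performs, i.e. use the buffer type's get_alloc_size hook when one is provided and otherwise fall back to the plain data size, which is presumably why the new code just calls it on the chosen buffer type. A tiny stand-alone sketch of that dispatch pattern, with mock types so it compiles on its own; nothing here is the real ggml API.

#include <cstddef>
#include <cstdio>

// Mock of a buffer-type interface whose get_alloc_size hook is optional.
struct mock_buft_iface {
    size_t (*get_alloc_size)(size_t nbytes); // may be null
};

static size_t padded_alloc_size(size_t nbytes) {
    const size_t align = 1024;                       // pretend backend padding
    return (nbytes + align - 1) / align * align;
}

static size_t buft_get_alloc_size(const mock_buft_iface & iface, size_t nbytes) {
    if (iface.get_alloc_size) {
        return iface.get_alloc_size(nbytes);         // backend-specific size
    }
    return nbytes;                                   // fallback: raw data size
}

int main() {
    mock_buft_iface cuda_like { padded_alloc_size };
    mock_buft_iface cpu_like  { nullptr };
    printf("cuda-like: %zu bytes\n", buft_get_alloc_size(cuda_like, 594)); // 1024
    printf("cpu-like : %zu bytes\n", buft_get_alloc_size(cpu_like,  594)); // 594
    return 0;
}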
@@ -996,20 +986,17 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
    struct ggml_context * ctx = ggml_init(params);
    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
    if (tensor == nullptr) {
        printf("Null tensor\n");
        fprintf(stderr,"Null tensor pointer passed to server init_tensor function.\n");
        ggml_free(ctx);
        return false;
    }

    printf("about to call buffer\n");

    //ggml_backend_init_tensor

    // Call the backend's buffer_init_tensor function
    ggml_backend_buffer_t buffer = tensor->buffer;
    if (buffer && buffer->iface.init_tensor) {
        printf("Calling buffer iface function\n");
        buffer->iface.init_tensor(buffer, tensor);
    } else {
        fprintf(stderr,"Null buffer for tensor passed to init_tensor function\n");
    }

    ggml_free(ctx);
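Aside (not part of this commit): both server handlers in this diff create a throwaway ggml_context just to materialize the deserialized tensor, and must remember to call ggml_free on every early-return path. One way to make that cleanup automatic is a small scope guard; the guard below is a sketch of that idea, assumes linking against ggml for ggml_init/ggml_free/ggml_tensor_overhead, and is not how the commit itself handles it.

#include "ggml.h"
#include <cstdio>

// Illustrative scope guard: frees the context when it goes out of scope.
struct ggml_ctx_guard {
    ggml_context * ctx;
    explicit ggml_ctx_guard(ggml_context * c) : ctx(c) {}
    ~ggml_ctx_guard() { if (ctx) { ggml_free(ctx); } }
    ggml_ctx_guard(const ggml_ctx_guard &) = delete;
    ggml_ctx_guard & operator=(const ggml_ctx_guard &) = delete;
};

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };
    ggml_context * ctx = ggml_init(params);
    ggml_ctx_guard guard(ctx);   // freed automatically on any return path
    printf("context %s\n", ctx ? "created" : "failed");
    return 0;
}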