Added init tensor calling code

matt23654 2024-12-31 21:56:51 +00:00
parent 0827b2c1da
commit 7aad6cbda6


@@ -93,9 +93,18 @@ enum rpc_cmd {
     RPC_CMD_COPY_TENSOR,
     RPC_CMD_GRAPH_COMPUTE,
     RPC_CMD_GET_DEVICE_MEMORY,
+    RPC_CMD_INIT_TENSOR,
     RPC_CMD_COUNT,
 };
 
+struct rpc_msg_init_tensor_req {
+    rpc_tensor tensor;
+};
+
+struct rpc_msg_init_tensor_rsp {
+    uint8_t result; // success/failure
+};
+
 struct rpc_msg_alloc_buffer_req {
     uint64_t size;
 };
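Both new message structs are fixed-size PODs, which matters because the RPC layer copies them over the socket byte-for-byte; note that rpc_msg_init_tensor_rsp is declared here but, as the later hunks show, the server currently answers with an empty message instead of sending it. A minimal compile-time guard for that wire-compatibility assumption (illustrative only, not part of the commit) could look like:

    #include <type_traits>

    // Assumption: request/response structs are sent as raw bytes, so they must
    // be trivially copyable and laid out identically on client and server.
    static_assert(std::is_trivially_copyable<rpc_msg_init_tensor_req>::value,
                  "rpc_msg_init_tensor_req is copied verbatim onto the wire");
    static_assert(std::is_trivially_copyable<rpc_msg_init_tensor_rsp>::value,
                  "rpc_msg_init_tensor_rsp is copied verbatim onto the wire");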
@@ -461,10 +470,18 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
 }
 
 static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
-    UNUSED(buffer);
+    //UNUSED(buffer);
+    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
 
     if (ggml_is_quantized(tensor->type)) {
         // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
-        GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
+        //GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
+        rpc_msg_init_tensor_req request;
+
+        request.tensor = serialize_tensor(tensor);
+
+        //rpc_msg_init_tensor_rsp response;
+        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
+        GGML_ASSERT(status);
     }
 }
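The call passes nullptr and 0 for the response buffer, so the client only waits for the server's empty acknowledgement. A sketch of the request/response framing this relies on — assuming a command byte plus 8-byte size prefix on each payload, and hypothetical send_data/recv_data socket helpers; not the actual llama.cpp implementation — might look like:

    // Sketch: one command byte, an 8-byte payload size, the payload, then a
    // size-prefixed response read back the same way.
    static bool send_rpc_cmd_sketch(sockfd_t sockfd, uint8_t cmd,
                                    const void * input,  size_t input_size,
                                    void *       output, size_t output_size) {
        if (!send_data(sockfd, &cmd, 1))           return false; // command id
        uint64_t in_size = input_size;
        if (!send_data(sockfd, &in_size, 8))       return false; // payload size
        if (!send_data(sockfd, input, input_size)) return false; // payload
        uint64_t out_size = 0;
        if (!recv_data(sockfd, &out_size, 8))      return false; // response size
        if (out_size != output_size)               return false; // size mismatch
        return recv_data(sockfd, output, output_size);           // body (may be empty)
    }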
@@ -757,6 +774,7 @@ public:
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
+    bool init_tensor(const rpc_msg_init_tensor_req & request);
 
 private:
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -905,6 +923,35 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
     return true;
 }
 
+bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    struct ggml_context * ctx = ggml_init(params);
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
+
+    if (tensor == nullptr) {
+        printf("Null tensor\n");
+        ggml_free(ctx);
+        return false;
+    }
+
+    printf("about to call buffer\n");
+
+    //ggml_backend_init_tensor
+    // Call the backend's buffer_init_tensor function
+    ggml_backend_buffer_t buffer = tensor->buffer;
+    if (buffer && buffer->iface.init_tensor) {
+        printf("Calling buffer iface function\n");
+        buffer->iface.init_tensor(buffer, tensor);
+    }
+
+    ggml_free(ctx);
+    return true;
+}
+
 bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
     struct ggml_init_params params {
         /*.mem_size   =*/ ggml_tensor_overhead(),
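The handler leans on a common ggml pattern: a throwaway context sized for exactly one tensor header (ggml_tensor_overhead() bytes) with no_alloc set, so deserialize_tensor can materialize the tensor's metadata without allocating any tensor data. A standalone sketch of that pattern, using only the public ggml API:

    // Metadata-only context: room for one tensor header, no data allocation.
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // tensor->data stays unset
    };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_tensor  * t   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 16);
    // ... inspect or fill in t's metadata here ...
    ggml_free(ctx);               // frees the header; no tensor data was owned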
@@ -1133,6 +1180,19 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
             }
             break;
         }
+        case RPC_CMD_INIT_TENSOR: {
+            rpc_msg_init_tensor_req request;
+            if (!recv_msg(sockfd, &request, sizeof(request))) {
+                return;
+            }
+            if (!server.init_tensor(request)) {
+                return;
+            }
+            if (!send_msg(sockfd, nullptr, 0)) {
+                return;
+            }
+            break;
+        }
         case RPC_CMD_GET_TENSOR: {
             rpc_msg_get_tensor_req request;
             if (!recv_msg(sockfd, &request, sizeof(request))) {
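Note that the new case replies with send_msg(sockfd, nullptr, 0): a zero-length message that serves purely as an acknowledgement, which is why the rpc_msg_init_tensor_rsp struct from the first hunk goes unused. If the handler's boolean result were reported explicitly instead of by dropping the connection on failure, the case might read as follows (a sketch of one alternative, not part of the commit):

    case RPC_CMD_INIT_TENSOR: {
        rpc_msg_init_tensor_req request;
        if (!recv_msg(sockfd, &request, sizeof(request))) {
            return;
        }
        rpc_msg_init_tensor_rsp response;
        response.result = server.init_tensor(request) ? 1 : 0; // success/failure
        if (!send_msg(sockfd, &response, sizeof(response))) {
            return;
        }
        break;
    }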