Refactor: Moves cuda graph maintenance (update or adjusting copy parameters) to separate function for improved readability.
This commit is contained in:
parent
22c2429496
commit
eb3ea69850
1 changed files with 50 additions and 43 deletions
|
@ -2337,6 +2337,55 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef USE_CUDA_GRAPH
|
||||||
|
static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
|
||||||
|
|
||||||
|
if (cuda_graph_update_required) {
|
||||||
|
// Extract nodes from graph
|
||||||
|
// First call with null argument gets number of nodes in graph
|
||||||
|
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
|
||||||
|
// Subsequent call with non-null argument gets nodes
|
||||||
|
cuda_ctx->cuda_graph->nodes.clear();
|
||||||
|
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
|
||||||
|
cuda_ctx->cuda_graph->params.clear();
|
||||||
|
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
|
||||||
|
if (cuda_ctx->cuda_graph->num_nodes > 0) {
|
||||||
|
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
|
||||||
|
|
||||||
|
// Loop over nodes, and extract kernel parameters from each node
|
||||||
|
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
|
||||||
|
cudaGraphNodeType node_type;
|
||||||
|
CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
|
||||||
|
if (node_type == cudaGraphNodeTypeKernel) {
|
||||||
|
cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
|
||||||
|
if (stat == cudaErrorInvalidDeviceFunction) {
|
||||||
|
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
|
||||||
|
// We don't need to update blas nodes, so clear error and move on.
|
||||||
|
cudaGetLastError();
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(stat == cudaSuccess);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// One of the arguments to the copy kernel is updated for each token, hence we need to
|
||||||
|
// replace that argument with the updated value in the CUDA graph
|
||||||
|
// on update steps, the live parameters will already be captured
|
||||||
|
int k = 0;
|
||||||
|
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
|
||||||
|
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
|
||||||
|
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
|
||||||
|
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
|
||||||
|
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
#ifdef USE_CUDA_GRAPH
|
#ifdef USE_CUDA_GRAPH
|
||||||
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) {
|
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) {
|
||||||
|
|
||||||
|
@ -2571,49 +2620,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
|
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
|
||||||
|
maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
|
||||||
if (cuda_graph_update_required) {
|
|
||||||
// Extract nodes from graph
|
|
||||||
// First call with null argument gets number of nodes in graph
|
|
||||||
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
|
|
||||||
// Subsequent call with non-null argument gets nodes
|
|
||||||
cuda_ctx->cuda_graph->nodes.clear();
|
|
||||||
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
|
|
||||||
cuda_ctx->cuda_graph->params.clear();
|
|
||||||
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
|
|
||||||
if (cuda_ctx->cuda_graph->num_nodes > 0) {
|
|
||||||
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
|
|
||||||
|
|
||||||
// Loop over nodes, and extract kernel parameters from each node
|
|
||||||
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
|
|
||||||
cudaGraphNodeType node_type;
|
|
||||||
CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
|
|
||||||
if (node_type == cudaGraphNodeTypeKernel) {
|
|
||||||
cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
|
|
||||||
if (stat == cudaErrorInvalidDeviceFunction) {
|
|
||||||
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
|
|
||||||
// We don't need to update blas nodes, so clear error and move on.
|
|
||||||
cudaGetLastError();
|
|
||||||
} else {
|
|
||||||
GGML_ASSERT(stat == cudaSuccess);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// One of the arguments to the copy kernel is updated for each token, hence we need to
|
|
||||||
// replace that argument with the updated value in the CUDA graph
|
|
||||||
if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
|
|
||||||
int k = 0;
|
|
||||||
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
|
|
||||||
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
|
|
||||||
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
|
|
||||||
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
|
|
||||||
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update graph executable
|
// Update graph executable
|
||||||
update_cuda_graph_executable(cuda_ctx);
|
update_cuda_graph_executable(cuda_ctx);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue