From eb3ea69850d5e2838f5c2eda4feb11bd43bdaef1 Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger Date: Wed, 8 Jan 2025 13:47:14 +0000 Subject: [PATCH] Refactor: Moves cuda graph maintenance (update or adjusting copy parameters) to separate function for improved readability. --- ggml/src/ggml-cuda/ggml-cuda.cu | 93 ++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 43 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 40a039274..de5709f06 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2337,6 +2337,55 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } #endif + +#ifdef USE_CUDA_GRAPH +static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) { + + if (cuda_graph_update_required) { + // Extract nodes from graph + // First call with null argument gets number of nodes in graph + CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); + // Subsequent call with non-null argument gets nodes + cuda_ctx->cuda_graph->nodes.clear(); + cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); + cuda_ctx->cuda_graph->params.clear(); + cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); + if (cuda_ctx->cuda_graph->num_nodes > 0) { + CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes)); + + // Loop over nodes, and extract kernel parameters from each node + for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { + cudaGraphNodeType node_type; + CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); + if (node_type == cudaGraphNodeTypeKernel) { + cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime + if (stat == cudaErrorInvalidDeviceFunction) { + // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. + // We don't need to update blas nodes, so clear error and move on. + cudaGetLastError(); + } else { + GGML_ASSERT(stat == cudaSuccess); + } + } + } + } + } else { + // One of the arguments to the copy kernel is updated for each token, hence we need to + // replace that argument with the updated value in the CUDA graph + // on update steps, the live parameters will already be captured + int k = 0; + for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { + if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) { + char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); + cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; + CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); + } + } + } +} +#endif + + #ifdef USE_CUDA_GRAPH static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) { @@ -2571,49 +2620,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, } // Perform update to graph (if required for this token), and change copy parameter (required for every token) - - if (cuda_graph_update_required) { - // Extract nodes from graph - // First call with null argument gets number of nodes in graph - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes)); - // Subsequent call with non-null argument gets nodes - cuda_ctx->cuda_graph->nodes.clear(); - cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes); - cuda_ctx->cuda_graph->params.clear(); - cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes); - if (cuda_ctx->cuda_graph->num_nodes > 0) { - CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes)); - - // Loop over nodes, and extract kernel parameters from each node - for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - cudaGraphNodeType node_type; - CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type)); - if (node_type == cudaGraphNodeTypeKernel) { - cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime - if (stat == cudaErrorInvalidDeviceFunction) { - // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. - // We don't need to update blas nodes, so clear error and move on. - cudaGetLastError(); - } else { - GGML_ASSERT(stat == cudaSuccess); - } - } - } - } - } - - // One of the arguments to the copy kernel is updated for each token, hence we need to - // replace that argument with the updated value in the CUDA graph - if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured - int k = 0; - for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) { - if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) { - char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++); - cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr; - CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i])); - } - } - } + maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required); // Update graph executable update_cuda_graph_executable(cuda_ctx);