#ifdef USE_CUDA_GRAPH
// Decides whether the CUDA graph kept in cuda_ctx must be updated before it can
// be used for cgraph, and refreshes the per-node properties stored in cuda_ctx
// so that the same comparison can be made for the next token.
//
// cuda_ctx                   - backend context; its cuda_graph->ggml_graph_properties
//                              vector is resized and overwritten by this call.
// cgraph                     - GGML graph whose nodes are compared and recorded.
// cuda_graph_update_required - decision made so far by the caller; a true value
//                              is never turned back into false here.
//
// Returns true when an update is required: the incoming flag was already set,
// no graph instance exists yet, the node count differs from the stored one, or
// at least one node no longer matches its stored properties.
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) {
    auto & stored_props = cuda_ctx->cuda_graph->ggml_graph_properties;

    bool needs_update = cuda_graph_update_required;

    // Without an existing graph instance there is nothing to reuse.
    if (cuda_ctx->cuda_graph->instance == nullptr) {
        needs_update = true;
    }

    // A different node count forces an update; resize the stored properties so
    // that every index visited by the loop below is valid.
    if (stored_props.size() != (size_t)cgraph->n_nodes) {
        needs_update = true;
        stored_props.resize(cgraph->n_nodes);
    }

    for (int idx = 0; idx < cgraph->n_nodes; idx++) {
        ggml_tensor * node  = cgraph->nodes[idx];
        auto *        props = &stored_props[idx];

        // The comparison is skipped once an update is already known to be needed.
        if (!needs_update && !ggml_graph_node_has_matching_properties(node, props)) {
            needs_update = true;
        }

        // The stored properties are refreshed for every node, whatever the
        // outcome of the comparison.
        set_ggml_graph_node_properties(node, props);
    }

    return needs_update;
}
#endif
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); - } - - // Loop over nodes in GGML graph to determine if CUDA graph update is required - // and store properties to allow this comparison for the next token - for (int i = 0; i < cgraph->n_nodes; i++) { - bool has_matching_properties = true; - if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); - } - if (!has_matching_properties) { - cuda_graph_update_required = true; - } - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); - } + cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph_update_required); // Loop over nodes in GGML graph to obtain info needed for CUDA graph cuda_ctx->cuda_graph->updated_kernel_arg.clear();