Refactor: Moves cuda graph update check to separate function.

This commit is contained in:
Andreas Kieslinger 2025-01-08 13:18:18 +00:00
parent ba0533100d
commit 22c2429496

View file

@ -2337,6 +2337,36 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
}
#endif
#ifdef USE_CUDA_GRAPH
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) {
if (cuda_ctx->cuda_graph->instance == nullptr) {
cuda_graph_update_required = true;
}
// Check if the graph size has changed
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
cuda_graph_update_required = true;
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
}
// Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token
for (int i = 0; i < cgraph->n_nodes; i++) {
bool has_matching_properties = true;
if (!cuda_graph_update_required) {
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
}
if (!has_matching_properties) {
cuda_graph_update_required = true;
}
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
}
return cuda_graph_update_required;
}
#endif
#ifdef USE_CUDA_GRAPH
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
@ -2398,28 +2428,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
}
if (use_cuda_graph) {
if (cuda_ctx->cuda_graph->instance == nullptr) {
cuda_graph_update_required = true;
}
// Check if the graph size has changed
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
cuda_graph_update_required = true;
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
}
// Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token
for (int i = 0; i < cgraph->n_nodes; i++) {
bool has_matching_properties = true;
if (!cuda_graph_update_required) {
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
}
if (!has_matching_properties) {
cuda_graph_update_required = true;
}
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
}
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph_update_required);
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->updated_kernel_arg.clear();