Refactor: Moves cuda graph update check to separate function.
This commit is contained in:
parent
ba0533100d
commit
22c2429496
1 changed files with 31 additions and 22 deletions
|
@ -2337,6 +2337,36 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef USE_CUDA_GRAPH
|
||||||
|
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required) {
|
||||||
|
|
||||||
|
if (cuda_ctx->cuda_graph->instance == nullptr) {
|
||||||
|
cuda_graph_update_required = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the graph size has changed
|
||||||
|
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
|
||||||
|
cuda_graph_update_required = true;
|
||||||
|
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loop over nodes in GGML graph to determine if CUDA graph update is required
|
||||||
|
// and store properties to allow this comparison for the next token
|
||||||
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
|
bool has_matching_properties = true;
|
||||||
|
if (!cuda_graph_update_required) {
|
||||||
|
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
|
||||||
|
}
|
||||||
|
if (!has_matching_properties) {
|
||||||
|
cuda_graph_update_required = true;
|
||||||
|
}
|
||||||
|
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return cuda_graph_update_required;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
#ifdef USE_CUDA_GRAPH
|
#ifdef USE_CUDA_GRAPH
|
||||||
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||||
|
@ -2398,28 +2428,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (use_cuda_graph) {
|
if (use_cuda_graph) {
|
||||||
if (cuda_ctx->cuda_graph->instance == nullptr) {
|
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph_update_required);
|
||||||
cuda_graph_update_required = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the graph size has changed
|
|
||||||
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
|
|
||||||
cuda_graph_update_required = true;
|
|
||||||
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop over nodes in GGML graph to determine if CUDA graph update is required
|
|
||||||
// and store properties to allow this comparison for the next token
|
|
||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
||||||
bool has_matching_properties = true;
|
|
||||||
if (!cuda_graph_update_required) {
|
|
||||||
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
|
|
||||||
}
|
|
||||||
if (!has_matching_properties) {
|
|
||||||
cuda_graph_update_required = true;
|
|
||||||
}
|
|
||||||
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||||
cuda_ctx->cuda_graph->updated_kernel_arg.clear();
|
cuda_ctx->cuda_graph->updated_kernel_arg.clear();
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue