Refactor: Move the CUDA graph executable update step to a separate function.
This commit is contained in:
parent
ecebbd292d
commit
ba0533100d
1 changed files with 24 additions and 15 deletions
|
@ -2337,6 +2337,28 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef USE_CUDA_GRAPH
|
||||||
|
// Brings the instantiated graph executable (cuda_ctx->cuda_graph->instance)
// in line with the current graph (cuda_ctx->cuda_graph->graph).
//
// An in-place update through cudaGraphExecUpdate is attempted first. If that
// call returns cudaErrorGraphExecUpdateFailure, the existing executable is
// destroyed and a new one is instantiated from the graph. Any other status
// than cudaSuccess trips GGML_ASSERT.
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {

    cudaGraphExecUpdateResultInfo update_info;
    const cudaError_t stat = cudaGraphExecUpdate(
        cuda_ctx->cuda_graph->instance,
        cuda_ctx->cuda_graph->graph,
        &update_info);

    if (stat != cudaErrorGraphExecUpdateFailure) {
        // No rebuild required; the update itself must have succeeded.
        GGML_ASSERT(stat == cudaSuccess);
        return;
    }

#ifndef NDEBUG
    GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
#endif

    // The existing executable violates the update constraints and cannot be
    // patched in place: clear the recorded error, then build a fresh one.
    cudaGetLastError();
    CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
    cuda_ctx->cuda_graph->instance = nullptr;
    CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
}
|
||||||
|
#endif
|
||||||
|
|
||||||
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
||||||
|
|
||||||
|
@ -2585,21 +2607,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update graph executable
|
// Update graph executable
|
||||||
cudaGraphExecUpdateResultInfo result_info;
|
update_cuda_graph_executable(cuda_ctx);
|
||||||
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
|
||||||
if (stat == cudaErrorGraphExecUpdateFailure) {
|
|
||||||
#ifndef NDEBUG
|
|
||||||
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
|
|
||||||
#endif
|
|
||||||
// The pre-existing graph exec cannot be updated due to violated constraints
|
|
||||||
// so instead clear error and re-instantiate
|
|
||||||
cudaGetLastError();
|
|
||||||
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
|
|
||||||
cuda_ctx->cuda_graph->instance = nullptr;
|
|
||||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
|
||||||
} else {
|
|
||||||
GGML_ASSERT(stat == cudaSuccess);
|
|
||||||
}
|
|
||||||
// Launch graph
|
// Launch graph
|
||||||
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
|
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
|
||||||
#else
|
#else
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue