Refactor: Improves structure and abstractions by moving CUDA graph evaluation and capture to its own function.
This commit is contained in:
parent
ed10ff58a6
commit
37518b7dda
1 changed files with 85 additions and 76 deletions
|
@ -2438,11 +2438,95 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
|
||||||
|
[[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph,
|
||||||
|
bool & cuda_graph_update_required) {
|
||||||
|
|
||||||
|
while (!graph_evaluated_or_captured) {
|
||||||
|
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
|
||||||
|
// With the use of CUDA graphs, the execution will be performed by the graph launch.
|
||||||
|
if (!use_cuda_graph || cuda_graph_update_required) {
|
||||||
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
|
ggml_tensor * node = cgraph->nodes[i];
|
||||||
|
|
||||||
|
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
|
||||||
|
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||||
|
if (node->src[j] != nullptr) {
|
||||||
|
assert(node->src[j]->buffer);
|
||||||
|
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
|
||||||
|
ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
|
||||||
|
if (!ok) {
|
||||||
|
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
||||||
|
}
|
||||||
|
GGML_ASSERT(ok);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef USE_CUDA_GRAPH
|
||||||
|
if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
|
||||||
|
if (cuda_ctx->cuda_graph->graph != nullptr) {
|
||||||
|
CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
|
||||||
|
cuda_ctx->cuda_graph->graph = nullptr;
|
||||||
|
}
|
||||||
|
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
if (disable_cuda_graphs_due_to_failed_capture) {
|
||||||
|
use_cuda_graph = false;
|
||||||
|
cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
|
||||||
|
#ifndef NDEBUG
|
||||||
|
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
|
||||||
|
#endif
|
||||||
|
} else {
|
||||||
|
graph_evaluated_or_captured = true; // CUDA graph has been captured
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
graph_evaluated_or_captured = true; // CUDA graph has been captured
|
||||||
|
} else {
|
||||||
|
graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (use_cuda_graph) {
|
||||||
|
if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
|
||||||
|
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
|
||||||
|
maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
|
||||||
|
|
||||||
|
// Update graph executable
|
||||||
|
update_cuda_graph_executable(cuda_ctx);
|
||||||
|
|
||||||
|
// Launch graph
|
||||||
|
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
|
||||||
|
#else
|
||||||
|
graph_evaluated_or_captured = true;
|
||||||
|
#endif // USE_CUDA_GRAPH
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||||
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
|
||||||
|
|
||||||
ggml_cuda_set_device(cuda_ctx->device);
|
ggml_cuda_set_device(cuda_ctx->device);
|
||||||
|
|
||||||
|
// vector of pointers to CUDA cpy kernels, which are required to identify
|
||||||
|
// kernel parameters which need updated in the graph for each token
|
||||||
|
std::vector<void *> ggml_cuda_cpy_fn_ptrs;
|
||||||
|
|
||||||
#ifdef USE_CUDA_GRAPH
|
#ifdef USE_CUDA_GRAPH
|
||||||
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
||||||
|
|
||||||
|
@ -2453,9 +2537,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||||
|
|
||||||
bool use_cuda_graph = true;
|
bool use_cuda_graph = true;
|
||||||
bool cuda_graph_update_required = false;
|
bool cuda_graph_update_required = false;
|
||||||
// vector of pointers to CUDA cpy kernels, which are required to identify
|
|
||||||
// kernel parameters which need updated in the graph for each token
|
|
||||||
std::vector<void *> ggml_cuda_cpy_fn_ptrs;
|
|
||||||
|
|
||||||
if (cuda_ctx->cuda_graph->graph == nullptr) {
|
if (cuda_ctx->cuda_graph->graph == nullptr) {
|
||||||
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
|
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
|
||||||
|
@ -2559,79 +2640,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
||||||
|
|
||||||
bool graph_evaluated_or_captured = false;
|
bool graph_evaluated_or_captured = false;
|
||||||
|
|
||||||
while (!graph_evaluated_or_captured) {
|
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
|
||||||
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
|
|
||||||
// With the use of CUDA graphs, the execution will be performed by the graph launch.
|
|
||||||
if (!use_cuda_graph || cuda_graph_update_required) {
|
|
||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
||||||
ggml_tensor * node = cgraph->nodes[i];
|
|
||||||
|
|
||||||
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifndef NDEBUG
|
|
||||||
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
|
|
||||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
|
||||||
if (node->src[j] != nullptr) {
|
|
||||||
assert(node->src[j]->buffer);
|
|
||||||
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
|
|
||||||
ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
|
|
||||||
if (!ok) {
|
|
||||||
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
|
||||||
}
|
|
||||||
GGML_ASSERT(ok);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#ifdef USE_CUDA_GRAPH
|
|
||||||
if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
|
|
||||||
if (cuda_ctx->cuda_graph->graph != nullptr) {
|
|
||||||
CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
|
|
||||||
cuda_ctx->cuda_graph->graph = nullptr;
|
|
||||||
}
|
|
||||||
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
|
|
||||||
|
|
||||||
#if 0
|
|
||||||
if (disable_cuda_graphs_due_to_failed_capture) {
|
|
||||||
use_cuda_graph = false;
|
|
||||||
cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
|
|
||||||
#ifndef NDEBUG
|
|
||||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
|
|
||||||
#endif
|
|
||||||
} else {
|
|
||||||
graph_evaluated_or_captured = true; // CUDA graph has been captured
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
graph_evaluated_or_captured = true; // CUDA graph has been captured
|
|
||||||
} else {
|
|
||||||
graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (use_cuda_graph) {
|
|
||||||
if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
|
|
||||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
|
|
||||||
maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
|
|
||||||
|
|
||||||
// Update graph executable
|
|
||||||
update_cuda_graph_executable(cuda_ctx);
|
|
||||||
|
|
||||||
// Launch graph
|
|
||||||
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
|
|
||||||
#else
|
|
||||||
graph_evaluated_or_captured = true;
|
|
||||||
#endif // USE_CUDA_GRAPH
|
|
||||||
}
|
|
||||||
|
|
||||||
return GGML_STATUS_SUCCESS;
|
return GGML_STATUS_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue