Fall back if graph capture fails and address other comments
This commit is contained in:
parent
909e4c664b
commit
58199503a8
2 changed files with 77 additions and 37 deletions
108
ggml-cuda.cu
108
ggml-cuda.cu
|
@ -48,11 +48,20 @@
|
|||
|
||||
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||
|
||||
[[noreturn]]
|
||||
static bool disable_cuda_graphs_due_to_failed_capture = false;
|
||||
|
||||
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
|
||||
int id = -1; // in case cudaGetDevice fails
|
||||
cudaGetDevice(&id);
|
||||
|
||||
if(strcmp(msg,"operation not permitted when stream is capturing")==0 ||
|
||||
strcmp(msg,"operation failed due to a previous error during capture")==0) {
|
||||
// CUDA graph capture has failed, but we can fall back to regular stream-based CUDA
|
||||
// so mark as failed, clear the error and return.
|
||||
disable_cuda_graphs_due_to_failed_capture = true;
|
||||
cudaGetLastError();
|
||||
return;
|
||||
}
|
||||
fprintf(stderr, "CUDA error: %s\n", msg);
|
||||
fprintf(stderr, " current device: %d, in function %s at %s:%d\n", id, func, file, line);
|
||||
fprintf(stderr, " %s\n", stmt);
|
||||
|
@ -2428,6 +2437,7 @@ struct ggml_cuda_graph {
|
|||
cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
|
||||
bool disable_due_to_gpu_arch = false;
|
||||
bool disable_due_to_too_many_updates = false;
|
||||
bool disable_due_to_failed_graph_capture = false;
|
||||
int number_consecutive_updates = 0;
|
||||
ggml_graph_node_properties ggml_graph_properties[MAX_NODES_IN_CUDA_GRAPH];
|
||||
};
|
||||
|
@ -2436,26 +2446,28 @@ struct ggml_cuda_graph {
|
|||
const bool disable_cuda_graphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr);
|
||||
|
||||
GGML_CALL static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
||||
graph_node_properties->node_address = node;
|
||||
graph_node_properties->node_address = node->data;
|
||||
graph_node_properties->node_op = node->op;
|
||||
for(int i=0; i<GGML_MAX_DIMS; i++) {
|
||||
graph_node_properties->ne[i] = node->ne[i];
|
||||
graph_node_properties->nb[i] = node->nb[i];
|
||||
}
|
||||
for(int i=0; i<GGML_MAX_SRC; i++) {
|
||||
graph_node_properties->src_address[i] = node->src[i];
|
||||
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_CALL static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
||||
if(node != graph_node_properties->node_address) return false;
|
||||
if(node->data != graph_node_properties->node_address &&
|
||||
node->op != GGML_OP_CPY && node->op != GGML_OP_VIEW) return false;
|
||||
if(node->op != graph_node_properties->node_op) return false;
|
||||
for(int i=0; i<GGML_MAX_DIMS; i++) {
|
||||
if(node->ne[i] != graph_node_properties->ne[i]) return false;
|
||||
if(node->nb[i] != graph_node_properties->nb[i]) return false;
|
||||
}
|
||||
for(int i=0; i<GGML_MAX_SRC; i++) {
|
||||
if(node->src[i] != graph_node_properties->src_address[i]) return false;
|
||||
if(node->src[i] && node->src[i]->data != graph_node_properties->src_address[i] &&
|
||||
node->op != GGML_OP_CPY && node->op != GGML_OP_VIEW) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -2467,46 +2479,54 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
// Objects required for CUDA Graph
|
||||
static ggml_cuda_graph cuda_graph;
|
||||
bool use_cuda_graph = (cuda_graph.count >= 7); //avoid CUDA graphs on first few steps due to incompatible initialisations.
|
||||
if(cuda_ctx->cuda_graph == nullptr)
|
||||
{
|
||||
cuda_ctx->cuda_graph = (ggml_cuda_graph *) malloc(sizeof(ggml_cuda_graph));
|
||||
}
|
||||
bool use_cuda_graph = (cuda_ctx->cuda_graph->count >= 7); //avoid CUDA graphs on first few steps due to incompatible initialisations.
|
||||
char ** updated_kernel_arg[MAX_NODES_IN_CUDA_GRAPH];
|
||||
bool cuda_graph_update_required = false;
|
||||
// pointer to CUDA cpy kernel, which is required to identify
|
||||
// kernel parameters which need updated in the graph for each token
|
||||
void * ggml_cuda_cpy_fn_ptr = nullptr;
|
||||
|
||||
if(cuda_graph.count == 0){
|
||||
if(cuda_ctx->cuda_graph->count == 0){
|
||||
if (ggml_cuda_info().devices[cuda_ctx->device].cc < 800){
|
||||
cuda_graph.disable_due_to_gpu_arch=true;
|
||||
cuda_ctx->cuda_graph->disable_due_to_gpu_arch=true;
|
||||
}
|
||||
}
|
||||
|
||||
// Disable CUDA graphs in presence of env var, old GPU or use-case which is changing too rapidly.
|
||||
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
|
||||
// or previous graph capture failure.
|
||||
// Also disable for multi-gpu for now. TO DO investigate
|
||||
if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || cuda_graph.disable_due_to_too_many_updates ||
|
||||
if(disable_cuda_graphs || cuda_ctx->cuda_graph->disable_due_to_gpu_arch ||
|
||||
cuda_ctx->cuda_graph->disable_due_to_too_many_updates || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture ||
|
||||
ggml_backend_cuda_get_device_count() > 1){
|
||||
use_cuda_graph = false;
|
||||
}
|
||||
|
||||
if(use_cuda_graph) {
|
||||
|
||||
if(cuda_graph.instance == nullptr) cuda_graph_update_required=true;
|
||||
if(cuda_ctx->cuda_graph->instance == nullptr) cuda_graph_update_required=true;
|
||||
|
||||
// Loop over nodes in GGML graph to determine if CUDA graph update is required
|
||||
// and store properties to allow this comparison for the next token
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
bool has_matching_properties = true;
|
||||
if(!cuda_graph_update_required) {
|
||||
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_graph.ggml_graph_properties[i]);
|
||||
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
|
||||
}
|
||||
if(!has_matching_properties) cuda_graph_update_required = true;
|
||||
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_graph.ggml_graph_properties[i]);
|
||||
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
|
||||
}
|
||||
|
||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||
int k=0;
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
if(node->op == GGML_OP_MUL_MAT_ID) {
|
||||
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
|
||||
}
|
||||
if(node->op == GGML_OP_SOFT_MAX) {
|
||||
if(node->src[1]->ne[1] > 1){
|
||||
use_cuda_graph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate
|
||||
|
@ -2524,12 +2544,12 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||
|
||||
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
|
||||
if(cuda_graph_update_required) {
|
||||
cuda_graph.number_consecutive_updates++;
|
||||
cuda_ctx->cuda_graph->number_consecutive_updates++;
|
||||
}
|
||||
else {
|
||||
cuda_graph.number_consecutive_updates = 0;
|
||||
cuda_ctx->cuda_graph->number_consecutive_updates = 0;
|
||||
}
|
||||
if (cuda_graph.number_consecutive_updates >= 4) cuda_graph.disable_due_to_too_many_updates = true;
|
||||
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
|
||||
}
|
||||
|
||||
if(use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
|
||||
|
@ -2540,11 +2560,16 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||
bool use_cuda_graph = false;
|
||||
bool cuda_graph_update_required = false;
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
|
||||
bool graph_evaluated_or_captured = false;
|
||||
|
||||
while(!graph_evaluated_or_captured) {
|
||||
// Temporarily avoid indenting here (and below the following if) to make code review easier
|
||||
|
||||
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
|
||||
// With the use of CUDA graphs, the execution will be performed by the graph launch.
|
||||
if(!use_cuda_graph || cuda_graph_update_required) {
|
||||
//temporarily avoid indenting here to make code review easier
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
|
@ -2571,12 +2596,23 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
if(use_cuda_graph && (cuda_graph_update_required)) { // End CUDA graph capture
|
||||
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph.graph));
|
||||
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
|
||||
if(disable_cuda_graphs_due_to_failed_capture) {
|
||||
use_cuda_graph = false;
|
||||
cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
|
||||
}
|
||||
else {
|
||||
graph_evaluated_or_captured = true; // CUDA graph has been captured
|
||||
}
|
||||
}
|
||||
else {
|
||||
graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
|
||||
}
|
||||
}
|
||||
if(use_cuda_graph){
|
||||
|
||||
if(cuda_graph.instance == nullptr) { // Create executable graph from captured graph.
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_graph.instance, cuda_graph.graph, NULL, NULL, 0));
|
||||
if(cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
}
|
||||
|
||||
|
||||
|
@ -2584,19 +2620,19 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||
|
||||
if(cuda_graph_update_required) {
|
||||
// Extract nodes from graph
|
||||
if(cuda_graph.num_nodes == 0) {
|
||||
if(cuda_ctx->cuda_graph->num_nodes == 0) {
|
||||
// First call with null argument gets number of nodes in graph
|
||||
CUDA_CHECK(cudaGraphGetNodes(cuda_graph.graph, nullptr, &cuda_graph.num_nodes));
|
||||
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
|
||||
}
|
||||
// Subsequent call with non-null argument gets nodes
|
||||
CUDA_CHECK(cudaGraphGetNodes(cuda_graph.graph, cuda_graph.nodes, &cuda_graph.num_nodes));
|
||||
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes, &cuda_ctx->cuda_graph->num_nodes));
|
||||
|
||||
// Loop over nodes, and extract kernel parameters from each node
|
||||
for(size_t i=0; i<cuda_graph.num_nodes; i++) {
|
||||
for(size_t i=0; i<cuda_ctx->cuda_graph->num_nodes; i++) {
|
||||
cudaGraphNodeType node_type;
|
||||
CUDA_CHECK(cudaGraphNodeGetType(cuda_graph.nodes[i], &node_type));
|
||||
CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
|
||||
if (node_type == cudaGraphNodeTypeKernel) {
|
||||
auto stat = cudaGraphKernelNodeGetParams(cuda_graph.nodes[i], &cuda_graph.params[i]); // Get params using runtime
|
||||
auto stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
|
||||
if(stat == cudaErrorInvalidDeviceFunction) {
|
||||
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
|
||||
// We don't need to update blas nodes, so clear error and move on.
|
||||
|
@ -2613,31 +2649,31 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
|
|||
// replace that argument with the updated value in the CUDA graph
|
||||
if(!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
|
||||
int k=0;
|
||||
for(size_t i=0; i<cuda_graph.num_nodes; i++) {
|
||||
if(cuda_graph.params[i].func == ggml_cuda_cpy_fn_ptr) {
|
||||
for(size_t i=0; i<cuda_ctx->cuda_graph->num_nodes; i++) {
|
||||
if(cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
|
||||
char ** updated_kernel_arg_ptr = updated_kernel_arg[k++];
|
||||
cuda_graph.params[i].kernelParams[1] = updated_kernel_arg_ptr;
|
||||
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_graph.nodes[i], &cuda_graph.params[i]));
|
||||
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
|
||||
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update graph executable
|
||||
cudaGraphExecUpdateResultInfo result_info;
|
||||
auto stat = cudaGraphExecUpdate(cuda_graph.instance, cuda_graph.graph, &result_info);
|
||||
auto stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
||||
if(stat == cudaErrorGraphExecUpdateFailure) {
|
||||
// The pre-existing graph exec cannot be updated due to violated constraints
|
||||
// so instead clear error and re-instantiate
|
||||
cudaGetLastError();
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_graph.instance, cuda_graph.graph, NULL, NULL, 0));
|
||||
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(stat == cudaSuccess);
|
||||
}
|
||||
// Launch graph
|
||||
CUDA_CHECK(cudaGraphLaunch(cuda_graph.instance, cuda_ctx->stream()));
|
||||
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
|
||||
}
|
||||
cuda_graph.count++;
|
||||
cuda_ctx->cuda_graph->count++;
|
||||
#endif // USE_CUDA_GRAPH
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
|
|
@ -172,7 +172,6 @@
|
|||
|
||||
#define GGML_CUDA_MAX_STREAMS 8
|
||||
|
||||
[[noreturn]]
|
||||
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
|
||||
|
||||
#define CUDA_CHECK_GEN(err, success, error_fn) \
|
||||
|
@ -479,6 +478,8 @@ struct ggml_tensor_extra_gpu {
|
|||
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
|
||||
};
|
||||
|
||||
struct ggml_cuda_graph;
|
||||
|
||||
struct ggml_backend_cuda_context {
|
||||
int device;
|
||||
std::string name;
|
||||
|
@ -487,6 +488,8 @@ struct ggml_backend_cuda_context {
|
|||
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
|
||||
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
||||
|
||||
ggml_cuda_graph * cuda_graph = nullptr;
|
||||
|
||||
explicit ggml_backend_cuda_context(int device) :
|
||||
device(device),
|
||||
name(GGML_CUDA_NAME + std::to_string(device)) {
|
||||
|
@ -506,6 +509,7 @@ struct ggml_backend_cuda_context {
|
|||
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
|
||||
}
|
||||
}
|
||||
if(cuda_graph != nullptr) free(cuda_graph);
|
||||
}
|
||||
|
||||
cudaStream_t stream(int device, int stream) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue