FIx issues raised in comments
This commit is contained in:
		
							parent
							
								
									cec409aa98
								
							
						
					
					
						commit
						c8dd0e7c1c
					
				
					 1 changed files with 43 additions and 19 deletions
				
			
		
							
								
								
									
										62
									
								
								ggml-cuda.cu
									
										
									
									
									
								
							
							
						
						
									
										62
									
								
								ggml-cuda.cu
									
										
									
									
									
								
							|  | @ -2405,23 +2405,33 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { | ||||||
|     GGML_UNUSED(backend); |     GGML_UNUSED(backend); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | #if (CUDART_VERSION >= 12000) | ||||||
|  | #define USE_CUDA_GRAPH | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  | #ifdef USE_CUDA_GRAPH | ||||||
|  | #define MAX_NODES_IN_CUDA_GRAPH 10000 | ||||||
| struct ggml_cudaGraph { | struct ggml_cudaGraph { | ||||||
|     int count=0; |     int count=0; | ||||||
|     cudaGraph_t graph = nullptr; |     cudaGraph_t graph = nullptr; | ||||||
|     cudaGraphExec_t instance = nullptr; |     cudaGraphExec_t instance = nullptr; | ||||||
|     size_t numNodes = 0; |     size_t numNodes = 0; | ||||||
|     int softmax_ne0 = 0; |     int softmax_ne0 = 0; | ||||||
|  |     cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; | ||||||
|  |     CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH]; | ||||||
|  |     cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH]; | ||||||
| }; | }; | ||||||
|  | #endif | ||||||
| 
 | 
 | ||||||
| GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { | GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { | ||||||
|     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; |     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; | ||||||
| 
 | 
 | ||||||
|     ggml_cuda_set_device(cuda_ctx->device); |     ggml_cuda_set_device(cuda_ctx->device); | ||||||
| 
 | 
 | ||||||
|  | #ifdef USE_CUDA_GRAPH | ||||||
|     // Objects required for CUDA Graph |     // Objects required for CUDA Graph | ||||||
| #define MAX_NODES_IN_CUDA_GRAPH 10000 |     static ggml_cudaGraph cudaGraph; | ||||||
|     static ggml_cudaGraph cudaGraph; //TO DO move this to a suitable persistant location (and avoid use of static memory) |     bool useCudaGraph = (cudaGraph.count>=7); //avoid CUDA graphs on first few steps due to incompatible initialisations. | ||||||
|     bool useCudaGraph = (cudaGraph.count>=2); //avoid CUDA graphs on first 2 steps due to incompatible initialisations. |  | ||||||
|     char** updatedKernelArg[MAX_NODES_IN_CUDA_GRAPH]; |     char** updatedKernelArg[MAX_NODES_IN_CUDA_GRAPH]; | ||||||
|     bool cudaGraphUpdateRequired = false; |     bool cudaGraphUpdateRequired = false; | ||||||
|     // pointer to CUDA cpy kernel, which is required to identify |     // pointer to CUDA cpy kernel, which is required to identify | ||||||
|  | @ -2458,6 +2468,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t | ||||||
|         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal)); |         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | #else | ||||||
|  |     bool useCudaGraph = false; | ||||||
|  |     bool cudaGraphUpdateRequired = false; | ||||||
|  | #endif | ||||||
|  |      | ||||||
|     // Only perfom the graph exection if CUDA graphs are not enebled, or we are capturing the graph. |     // Only perfom the graph exection if CUDA graphs are not enebled, or we are capturing the graph. | ||||||
|     // With use of CUDA graphs, the execution will be performed by the graph launch. |     // With use of CUDA graphs, the execution will be performed by the graph launch. | ||||||
|     if(!useCudaGraph || cudaGraphUpdateRequired) { |     if(!useCudaGraph || cudaGraphUpdateRequired) { | ||||||
|  | @ -2486,6 +2501,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t | ||||||
|     } |     } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     #ifdef USE_CUDA_GRAPH | ||||||
|     if(useCudaGraph && (cudaGraphUpdateRequired)) { // End CUDA graph capture |     if(useCudaGraph && (cudaGraphUpdateRequired)) { // End CUDA graph capture | ||||||
|         CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cudaGraph.graph)); |         CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cudaGraph.graph)); | ||||||
|     } |     } | ||||||
|  | @ -2498,26 +2514,26 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t | ||||||
| 
 | 
 | ||||||
|         // Perform update to graph (if required for this token), and change copy parameter (required for every token) |         // Perform update to graph (if required for this token), and change copy parameter (required for every token) | ||||||
| 
 | 
 | ||||||
|         cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH]; |  | ||||||
|         CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH]; |  | ||||||
|         cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH]; |  | ||||||
| 
 |  | ||||||
|         if(cudaGraphUpdateRequired) { |         if(cudaGraphUpdateRequired) { | ||||||
|             // Extract nodes from graph |             // Extract nodes from graph | ||||||
|             if(cudaGraph.numNodes == 0) { |             if(cudaGraph.numNodes == 0) { | ||||||
|                 CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nullptr, &cudaGraph.numNodes)); |                 CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nullptr, &cudaGraph.numNodes)); | ||||||
|             } |             } | ||||||
|             CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nodes, &cudaGraph.numNodes)); |             CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, cudaGraph.nodes, &cudaGraph.numNodes)); | ||||||
| 
 | 
 | ||||||
|             // Loop over nodes, and extract kernel parameters fro each node |             // Loop over nodes, and extract kernel parameters fro each node | ||||||
|             for(size_t i=0; i<cudaGraph.numNodes; i++) { |             for(size_t i=0; i<cudaGraph.numNodes; i++) { | ||||||
|                 // We currently get a set of params using both driver and runtime, to work around an issue (see below) |                 CUgraphNodeType nodeType; | ||||||
|                 CU_CHECK(cuGraphKernelNodeGetParams(nodes[i], ¶msDriver[i])); // Get params using driver |                 CU_CHECK(cuGraphNodeGetType(cudaGraph.nodes[i], &nodeType)); | ||||||
|                 cudaError_t statRT = cudaGraphKernelNodeGetParams(nodes[i], ¶msRuntime[i]); // Get params using runtime |                 if (nodeType == CU_GRAPH_NODE_TYPE_KERNEL) { | ||||||
|                 if(statRT == 98) { |                     // We currently get a set of params using both driver and runtime, to work around an issue (see below) | ||||||
|                     // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. |                     CU_CHECK(cuGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.paramsDriver[i])); // Get params using driver | ||||||
|                     // We don't need to update blas nodes, so clear error and move on. |                     auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.paramsRuntime[i]); // Get params using runtime | ||||||
|                     cudaGetLastError(); |                     if(statRT == cudaErrorInvalidDeviceFunction) { | ||||||
|  |                         // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node. | ||||||
|  |                         // We don't need to update blas nodes, so clear error and move on. | ||||||
|  |                         cudaGetLastError(); | ||||||
|  |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  | @ -2529,22 +2545,30 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t | ||||||
|         if(!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured |         if(!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured | ||||||
|             int k=0; |             int k=0; | ||||||
|             for(size_t i=0; i<cudaGraph.numNodes; i++) { |             for(size_t i=0; i<cudaGraph.numNodes; i++) { | ||||||
|                 if(paramsRuntime[i].func == ggmlCudaCpyFn) { |                 if(cudaGraph.paramsRuntime[i].func == ggmlCudaCpyFn) { | ||||||
|                     char** updatedKernelArgPointer = updatedKernelArg[k++]; |                     char** updatedKernelArgPointer = updatedKernelArg[k++]; | ||||||
|                     paramsDriver[i].kernelParams[1] = updatedKernelArgPointer; |                     cudaGraph.paramsDriver[i].kernelParams[1] = updatedKernelArgPointer; | ||||||
|                     CU_CHECK(cuGraphKernelNodeSetParams(nodes[i], ¶msDriver[i])); |                     CU_CHECK(cuGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.paramsDriver[i])); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         // Update graph executable |         // Update graph executable | ||||||
|         cudaGraphExecUpdateResultInfo resultInfo; |         cudaGraphExecUpdateResultInfo resultInfo; | ||||||
|         CUDA_CHECK(cudaGraphExecUpdate(cudaGraph.instance, cudaGraph.graph, &resultInfo)); |         auto stat = cudaGraphExecUpdate(cudaGraph.instance, cudaGraph.graph, &resultInfo); | ||||||
|  |         if(stat == cudaErrorGraphExecUpdateFailure) | ||||||
|  |         { | ||||||
|  |             // The pre-existing graph exec cannot be updated due to violated constraints | ||||||
|  |             // so instead clar error and re-instantiate | ||||||
|  |             cudaGetLastError(); | ||||||
|  |             CUDA_CHECK(cudaGraphInstantiate(&cudaGraph.instance, cudaGraph.graph, NULL, NULL, 0));           | ||||||
|  |         } | ||||||
| 
 | 
 | ||||||
|         // Launch graph |         // Launch graph | ||||||
|         CUDA_CHECK(cudaGraphLaunch(cudaGraph.instance, cuda_ctx->stream())); |         CUDA_CHECK(cudaGraphLaunch(cudaGraph.instance, cuda_ctx->stream())); | ||||||
|     } |     } | ||||||
|     cudaGraph.count++; |     cudaGraph.count++; | ||||||
|  | #endif | ||||||
|     return GGML_STATUS_SUCCESS; |     return GGML_STATUS_SUCCESS; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue