cuda : add batched cuBLAS GEMM for faster attention (#3749)
* cmake : add helper for faster CUDA builds * batched : add NGL arg * ggml : skip nops in compute_forward * cuda : minor indentation * cuda : batched cuBLAS GEMMs for src0 F16 and src1 F32 (attention ops) * Apply suggestions from code review These changes plus: ```c++ #define cublasGemmBatchedEx hipblasGemmBatchedEx ``` are needed to compile with ROCM. I haven't done performance testing, but it seems to work. I couldn't figure out how to propose a change for lines outside what the pull changed, also this is the first time trying to create a multi-part review so please forgive me if I mess something up. * cuda : add ROCm / hipBLAS cublasGemmBatchedEx define * cuda : add cublasGemmStridedBatchedEx for non-broadcasted cases * cuda : reduce mallocs in cublasGemmBatchedEx branch * cuda : add TODO for calling cublas from kernel + using mem pool --------- Co-authored-by: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>
This commit is contained in:
		
							parent
							
								
									daab3d7f45
								
							
						
					
					
						commit
						2b4ea35e56
					
				
					 4 changed files with 193 additions and 13 deletions
				
			
		|  | @ -331,6 +331,7 @@ if (LLAMA_CUBLAS) | ||||||
|             set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics |             set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics | ||||||
|         else() |         else() | ||||||
|             set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics |             set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics | ||||||
|  |             #set(CMAKE_CUDA_ARCHITECTURES "") # use this to compile much faster, but only F16 models work | ||||||
|         endif() |         endif() | ||||||
|     endif() |     endif() | ||||||
|     message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") |     message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") | ||||||
|  |  | ||||||
|  | @ -11,7 +11,7 @@ int main(int argc, char ** argv) { | ||||||
|     gpt_params params; |     gpt_params params; | ||||||
| 
 | 
 | ||||||
|     if (argc == 1 || argv[1][0] == '-') { |     if (argc == 1 || argv[1][0] == '-') { | ||||||
|         printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN]\n" , argv[0]); |         printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN] [NGL]\n" , argv[0]); | ||||||
|         return 1 ; |         return 1 ; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -21,6 +21,9 @@ int main(int argc, char ** argv) { | ||||||
|     // total length of the sequences including the prompt
 |     // total length of the sequences including the prompt
 | ||||||
|     int n_len = 32; |     int n_len = 32; | ||||||
| 
 | 
 | ||||||
|  |     // number of layers to offload to the GPU
 | ||||||
|  |     int n_gpu_layers = 0; | ||||||
|  | 
 | ||||||
|     if (argc >= 2) { |     if (argc >= 2) { | ||||||
|         params.model = argv[1]; |         params.model = argv[1]; | ||||||
|     } |     } | ||||||
|  | @ -37,6 +40,10 @@ int main(int argc, char ** argv) { | ||||||
|         n_len = std::atoi(argv[4]); |         n_len = std::atoi(argv[4]); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     if (argc >= 6) { | ||||||
|  |         n_gpu_layers = std::atoi(argv[5]); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     if (params.prompt.empty()) { |     if (params.prompt.empty()) { | ||||||
|         params.prompt = "Hello my name is"; |         params.prompt = "Hello my name is"; | ||||||
|     } |     } | ||||||
|  | @ -49,7 +56,7 @@ int main(int argc, char ** argv) { | ||||||
| 
 | 
 | ||||||
|     llama_model_params model_params = llama_model_default_params(); |     llama_model_params model_params = llama_model_default_params(); | ||||||
| 
 | 
 | ||||||
|     // model_params.n_gpu_layers = 99; // offload all layers to the GPU
 |     model_params.n_gpu_layers = n_gpu_layers; | ||||||
| 
 | 
 | ||||||
|     llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); |     llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										182
									
								
								ggml-cuda.cu
									
										
									
									
									
								
							
							
						
						
									
										182
									
								
								ggml-cuda.cu
									
										
									
									
									
								
							|  | @ -29,6 +29,8 @@ | ||||||
| #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) | #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) | ||||||
| #define cublasCreate hipblasCreate | #define cublasCreate hipblasCreate | ||||||
| #define cublasGemmEx hipblasGemmEx | #define cublasGemmEx hipblasGemmEx | ||||||
|  | #define cublasGemmBatchedEx hipblasGemmBatchedEx | ||||||
|  | #define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx | ||||||
| #define cublasHandle_t hipblasHandle_t | #define cublasHandle_t hipblasHandle_t | ||||||
| #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS | #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS | ||||||
| #define cublasSetStream hipblasSetStream | #define cublasSetStream hipblasSetStream | ||||||
|  | @ -4345,13 +4347,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous | ||||||
|             break; |             break; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; |  | ||||||
|         const float xi = __half2float(x[ix]); |  | ||||||
| 
 |  | ||||||
|         const int row_y = col_x; |         const int row_y = col_x; | ||||||
| 
 | 
 | ||||||
|  |         const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; | ||||||
|         const int iy = channel*nrows_y + row_y; |         const int iy = channel*nrows_y + row_y; | ||||||
| 
 | 
 | ||||||
|  |         const float xi = __half2float(x[ix]); | ||||||
|  | 
 | ||||||
|         tmp += xi * y[iy]; |         tmp += xi * y[iy]; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -7013,7 +7015,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ | static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ | ||||||
|     GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)); |     GGML_ASSERT(!ggml_is_transposed(src0)); | ||||||
|  |     GGML_ASSERT(!ggml_is_transposed(src1)); | ||||||
|     GGML_ASSERT(!ggml_is_permuted(src0)); |     GGML_ASSERT(!ggml_is_permuted(src0)); | ||||||
|     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); |     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); | ||||||
|     GGML_ASSERT(src0->type == GGML_TYPE_F16); |     GGML_ASSERT(src0->type == GGML_TYPE_F16); | ||||||
|  | @ -7023,11 +7026,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor | ||||||
|     const int64_t ne01 = src0->ne[1]; |     const int64_t ne01 = src0->ne[1]; | ||||||
|     const int64_t ne02 = src0->ne[2]; |     const int64_t ne02 = src0->ne[2]; | ||||||
| 
 | 
 | ||||||
|     const int64_t ne12 = src1->ne[2]; |  | ||||||
| 
 |  | ||||||
|     const int64_t nb01 = src0->nb[1]; |     const int64_t nb01 = src0->nb[1]; | ||||||
|     const int64_t nb02 = src0->nb[2]; |     const int64_t nb02 = src0->nb[2]; | ||||||
| 
 | 
 | ||||||
|  |     const int64_t ne12 = src1->ne[2]; | ||||||
|  | 
 | ||||||
|     CUDA_CHECK(ggml_cuda_set_device(g_main_device)); |     CUDA_CHECK(ggml_cuda_set_device(g_main_device)); | ||||||
|     cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; |     cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; | ||||||
| 
 | 
 | ||||||
|  | @ -7046,6 +7049,159 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor | ||||||
|     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); |     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ | ||||||
|  |     GGML_ASSERT(!ggml_is_transposed(src0)); | ||||||
|  |     GGML_ASSERT(!ggml_is_transposed(src1)); | ||||||
|  |     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); | ||||||
|  |     GGML_ASSERT(src0->type == GGML_TYPE_F16); | ||||||
|  |     GGML_ASSERT(src1->type == GGML_TYPE_F32); | ||||||
|  | 
 | ||||||
|  |     const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00); | ||||||
|  |     const int64_t ne01 = src0->ne[1]; | ||||||
|  |     const int64_t ne02 = src0->ne[2]; | ||||||
|  |     const int64_t ne03 = src0->ne[3]; | ||||||
|  | 
 | ||||||
|  |     const int64_t nb01 = src0->nb[1]; | ||||||
|  |     const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02); | ||||||
|  |     const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03); | ||||||
|  | 
 | ||||||
|  |     const int64_t ne10 = src1->ne[0]; | ||||||
|  |     const int64_t ne11 = src1->ne[1]; | ||||||
|  |     const int64_t ne12 = src1->ne[2]; | ||||||
|  |     const int64_t ne13 = src1->ne[3]; | ||||||
|  | 
 | ||||||
|  |     const int64_t nb11 = src1->nb[1]; | ||||||
|  |     const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12); | ||||||
|  |     const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13); | ||||||
|  | 
 | ||||||
|  |     const int64_t ne1 = ggml_nelements(src1); | ||||||
|  |     const int64_t ne  = ggml_nelements(dst); | ||||||
|  | 
 | ||||||
|  |     CUDA_CHECK(ggml_cuda_set_device(g_main_device)); | ||||||
|  |     cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; | ||||||
|  | 
 | ||||||
|  |     int id; | ||||||
|  |     CUDA_CHECK(cudaGetDevice(&id)); | ||||||
|  |     CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream)); | ||||||
|  | 
 | ||||||
|  |     ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; | ||||||
|  |     void * src0_ddq = src0_extra->data_device[g_main_device]; | ||||||
|  |     half * src0_as_f16 = (half *) src0_ddq; | ||||||
|  | 
 | ||||||
|  |     ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; | ||||||
|  |     float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; | ||||||
|  | 
 | ||||||
|  |     ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; | ||||||
|  |     float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; | ||||||
|  | 
 | ||||||
|  |     // convert src1 to fp16 | ||||||
|  |     const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); | ||||||
|  |     GGML_ASSERT(to_fp16_cuda != nullptr); | ||||||
|  | 
 | ||||||
|  |     size_t src1_as = 0; | ||||||
|  |     half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); | ||||||
|  |     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); | ||||||
|  | 
 | ||||||
|  |     size_t dst_as = 0; | ||||||
|  |     half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); | ||||||
|  | 
 | ||||||
|  |     GGML_ASSERT(ne12 % ne02 == 0); | ||||||
|  |     GGML_ASSERT(ne13 % ne03 == 0); | ||||||
|  | 
 | ||||||
|  |     // broadcast factors | ||||||
|  |     const int64_t r2 = ne12/ne02; | ||||||
|  |     const int64_t r3 = ne13/ne03; | ||||||
|  | 
 | ||||||
|  |     const half alpha_f16 = 1.0f; | ||||||
|  |     const half beta_f16  = 0.0f; | ||||||
|  | 
 | ||||||
|  | #if 0 | ||||||
|  |     // use cublasGemmEx | ||||||
|  |     { | ||||||
|  |         for (int i13 = 0; i13 < ne13; ++i13) { | ||||||
|  |             for (int i12 = 0; i12 < ne12; ++i12) { | ||||||
|  |                 int i03 = i13 / r3; | ||||||
|  |                 int i02 = i12 / r2; | ||||||
|  | 
 | ||||||
|  |                 CUBLAS_CHECK( | ||||||
|  |                         cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, | ||||||
|  |                             ne01, ne11, ne10, | ||||||
|  |                             &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half), | ||||||
|  |                                         (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float), | ||||||
|  |                             &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01, | ||||||
|  |                             CUBLAS_COMPUTE_16F, | ||||||
|  |                             CUBLAS_GEMM_DEFAULT_TENSOR_OP)); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | #else | ||||||
|  |     if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) { | ||||||
|  |         // there is no broadcast and src0, src1 are contiguous across dims 2, 3 | ||||||
|  |         // use cublasGemmStridedBatchedEx | ||||||
|  |         CUBLAS_CHECK( | ||||||
|  |         cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, | ||||||
|  |                 ne01, ne11, ne10, | ||||||
|  |                 &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA | ||||||
|  |                             (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB | ||||||
|  |                 &beta_f16,  (      char *)     dst_f16, CUDA_R_16F, ne01,                dst->nb[2]/sizeof(float), // strideC | ||||||
|  |                 ne12*ne13, | ||||||
|  |                 CUBLAS_COMPUTE_16F, | ||||||
|  |                 CUBLAS_GEMM_DEFAULT_TENSOR_OP)); | ||||||
|  |     } else { | ||||||
|  |         // use cublasGemmBatchedEx | ||||||
|  |         // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000 | ||||||
|  |         const int ne23 = ne12*ne13; | ||||||
|  | 
 | ||||||
|  |         // TODO: avoid this alloc | ||||||
|  |         void ** ptrs = (void **) malloc(3*ne23*sizeof(void *)); | ||||||
|  | 
 | ||||||
|  |         for (int i13 = 0; i13 < ne13; ++i13) { | ||||||
|  |             for (int i12 = 0; i12 < ne12; ++i12) { | ||||||
|  |                 int i03 = i13 / r3; | ||||||
|  |                 int i02 = i12 / r2; | ||||||
|  | 
 | ||||||
|  |                 ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]; | ||||||
|  |                 ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2; | ||||||
|  |                 ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // allocate device memory for pointers | ||||||
|  |         void ** ptrs_as = nullptr; | ||||||
|  |         CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *))); | ||||||
|  | 
 | ||||||
|  |         // TODO: this does not work for some reason -- not sure why? | ||||||
|  |         //size_t ptrs_s = 0; | ||||||
|  |         //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s); | ||||||
|  | 
 | ||||||
|  |         // copy pointers to device | ||||||
|  |         CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice)); | ||||||
|  | 
 | ||||||
|  |         free(ptrs); | ||||||
|  | 
 | ||||||
|  |         CUBLAS_CHECK( | ||||||
|  |         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, | ||||||
|  |                 ne01, ne11, ne10, | ||||||
|  |                 &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half), | ||||||
|  |                             (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float), | ||||||
|  |                 &beta_f16,  (      void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01, | ||||||
|  |                 ne23, | ||||||
|  |                 CUBLAS_COMPUTE_16F, | ||||||
|  |                 CUBLAS_GEMM_DEFAULT_TENSOR_OP)); | ||||||
|  | 
 | ||||||
|  |         // free device memory for pointers | ||||||
|  |         CUDA_CHECK(cudaFree(ptrs_as)); | ||||||
|  |         //ggml_cuda_pool_free(ptrs_as, ptrs_s); | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  | 
 | ||||||
|  |     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); | ||||||
|  |     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); | ||||||
|  | 
 | ||||||
|  |     ggml_cuda_pool_free(src1_as_f16, src1_as); | ||||||
|  |     ggml_cuda_pool_free(dst_f16, dst_as); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||||
|     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && |     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && | ||||||
|         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; |         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; | ||||||
|  | @ -7058,10 +7214,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     // debug helpers | ||||||
|  |     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); | ||||||
|  |     //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); | ||||||
|  |     //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]); | ||||||
|  |     //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]); | ||||||
|  |     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); | ||||||
|  |     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); | ||||||
|  | 
 | ||||||
|     if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { |     if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { | ||||||
|  |         // KQ | ||||||
|         ggml_cuda_mul_mat_vec_p021(src0, src1, dst); |         ggml_cuda_mul_mat_vec_p021(src0, src1, dst); | ||||||
|     } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) { |     } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { | ||||||
|  |         // KQV | ||||||
|         ggml_cuda_mul_mat_vec_nc(src0, src1, dst); |         ggml_cuda_mul_mat_vec_nc(src0, src1, dst); | ||||||
|  |     } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { | ||||||
|  |         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); | ||||||
|     } else if (src0->type == GGML_TYPE_F32) { |     } else if (src0->type == GGML_TYPE_F32) { | ||||||
|         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); |         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); | ||||||
|     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { |     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { | ||||||
|  |  | ||||||
							
								
								
									
										4
									
								
								ggml.c
									
										
									
									
									
								
							
							
						
						
									
										4
									
								
								ggml.c
									
										
									
									
									
								
							|  | @ -16602,6 +16602,10 @@ static void ggml_compute_forward_cross_entropy_loss_back( | ||||||
| static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { | static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { | ||||||
|     GGML_ASSERT(params); |     GGML_ASSERT(params); | ||||||
| 
 | 
 | ||||||
|  |     if (tensor->op == GGML_OP_NONE) { | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
| #ifdef GGML_USE_CUBLAS | #ifdef GGML_USE_CUBLAS | ||||||
|     bool skip_cpu = ggml_cuda_compute_forward(params, tensor); |     bool skip_cpu = ggml_cuda_compute_forward(params, tensor); | ||||||
|     if (skip_cpu) { |     if (skip_cpu) { | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue