feat: cuda implementation for ggml_conv_transpose_1d (ggml/854)
				
					
				
			* conv transpose 1d passing test for 1d input and kernel * working for different input and output channel counts, added test for variable stride * initial draft appears to work with stride other than 1 * working with all old and new conv1d tests * added a test for large tensors * removed use cuda hardcoding * restored test-conv-transpose.c * removed unused arugments, and fixed bug where test failure would cause subsequent tests to fail * fixed accumulator bug * added test to test-backend-ops * fixed mistake * addressed review * fixed includes * removed blank lines * style and warning fixes * return failure when test fails * fix supports_op --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
		
							parent
							
								
									470939d483
								
							
						
					
					
						commit
						fde13b3bb9
					
				
					 4 changed files with 146 additions and 1 deletions
				
			
		|  | @ -29,6 +29,7 @@ | |||
| #include "ggml-cuda/tsembd.cuh" | ||||
| #include "ggml-cuda/unary.cuh" | ||||
| #include "ggml-cuda/upscale.cuh" | ||||
| #include "ggml-cuda/conv-transpose-1d.cuh" | ||||
| 
 | ||||
| #include <algorithm> | ||||
| #include <array> | ||||
|  | @ -2261,6 +2262,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg | |||
|         case GGML_OP_IM2COL: | ||||
|             ggml_cuda_op_im2col(ctx, dst); | ||||
|             break; | ||||
|         case GGML_OP_CONV_TRANSPOSE_1D: | ||||
|             ggml_cuda_op_conv_transpose_1d(ctx,dst); | ||||
|             break; | ||||
|         case GGML_OP_POOL_2D: | ||||
|             ggml_cuda_op_pool2d(ctx, dst); | ||||
|             break; | ||||
|  | @ -2804,6 +2808,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons | |||
|                 ggml_type src0_type = op->src[0]->type; | ||||
|                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; | ||||
|             } break; | ||||
|         case GGML_OP_CONV_TRANSPOSE_1D: | ||||
|             { | ||||
|                 ggml_type src0_type = op->src[0]->type; | ||||
|                 ggml_type src1_type = op->src[1]->type; | ||||
|                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { | ||||
|                     return true; | ||||
|                 } | ||||
|                 return false; | ||||
|             } break; | ||||
|         case GGML_OP_NONE: | ||||
|         case GGML_OP_RESHAPE: | ||||
|         case GGML_OP_VIEW: | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue