ggml: metal unary exp & neg
There isn't much performance gain, though; this is mainly for broader op coverage. Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
		
							parent
							
								
									d564c4b534
								
							
						
					
					
						commit
						3a2a97af28
					
				
					 2 changed files with 46 additions and 0 deletions
				
			
		|  | @ -138,6 +138,8 @@ enum ggml_metal_kernel_type { | |||
|     GGML_METAL_KERNEL_TYPE_SCALE_4, | ||||
|     GGML_METAL_KERNEL_TYPE_CLAMP, | ||||
|     GGML_METAL_KERNEL_TYPE_TANH, | ||||
|     GGML_METAL_KERNEL_TYPE_EXP, | ||||
|     GGML_METAL_KERNEL_TYPE_NEG, | ||||
|     GGML_METAL_KERNEL_TYPE_RELU, | ||||
|     GGML_METAL_KERNEL_TYPE_SIGMOID, | ||||
|     GGML_METAL_KERNEL_TYPE_GELU, | ||||
|  | @ -745,6 +747,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de | |||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                       scale_4,                        true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP,                         clamp,                          true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                          tanh,                           true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP,                           exp,                            true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                           neg,                            true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                          relu,                           true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,                       sigmoid,                        true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                          gelu,                           true); | ||||
|  | @ -1184,6 +1188,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex | |||
|                 case GGML_UNARY_OP_GELU_QUICK: | ||||
|                 case GGML_UNARY_OP_SILU: | ||||
|                 case GGML_UNARY_OP_ELU: | ||||
|                 case GGML_UNARY_OP_EXP: | ||||
|                 case GGML_UNARY_OP_NEG: | ||||
|                     return ggml_is_contiguous(op->src[0]); | ||||
|                 default: | ||||
|                     return false; | ||||
|  | @ -1751,6 +1757,30 @@ static void ggml_metal_encode_node( | |||
| 
 | ||||
|                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; | ||||
|                 } break; | ||||
|                 case GGML_UNARY_OP_EXP: | ||||
|                 { | ||||
|                     id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline; | ||||
| 
 | ||||
|                     [encoder setComputePipelineState:pipeline]; | ||||
|                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; | ||||
|                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1]; | ||||
| 
 | ||||
|                     const int64_t n = ggml_nelements(dst); | ||||
| 
 | ||||
|                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; | ||||
|                 } break; | ||||
|                 case GGML_UNARY_OP_NEG: | ||||
|                 { | ||||
|                     id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline; | ||||
| 
 | ||||
|                     [encoder setComputePipelineState:pipeline]; | ||||
|                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; | ||||
|                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1]; | ||||
| 
 | ||||
|                     const int64_t n = ggml_nelements(dst); | ||||
| 
 | ||||
|                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; | ||||
|                 } break; | ||||
|                 case GGML_UNARY_OP_RELU: | ||||
|                 { | ||||
|                     id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline; | ||||
|  |  | |||
|  | @ -840,6 +840,22 @@ kernel void kernel_tanh( | |||
|     dst[tpig] = precise::tanh(x); | ||||
| } | ||||
| 
 | ||||
// kernel_exp: element-wise exponential, dst[i] = exp(src0[i]).
//   src0 - input float buffer (read-only).
//   dst  - output float buffer, written at the same index that is read from src0.
//   tpig - this thread's position in the grid; used directly as the element index.
// Each thread handles exactly one element and performs no bounds check, so the
// dispatch must not launch more threads than there are elements.
// NOTE(review): the host-side encoder appears to dispatch ggml_nelements(dst)
// single-thread threadgroups for this op - confirm if the dispatch shape changes.
| kernel void kernel_exp( | ||||
|         device const float * src0, | ||||
|         device       float * dst, | ||||
|         uint tpig[[thread_position_in_grid]]) { | ||||
|     device const float & x = src0[tpig]; | ||||
// The precise:: variant is requested explicitly, mirroring kernel_tanh's precise::tanh.
|     dst[tpig] = precise::exp(x); | ||||
| } | ||||
| 
 | ||||
// kernel_neg: element-wise negation, dst[i] = -src0[i].
//   src0 - input float buffer (read-only).
//   dst  - output float buffer, written at the same index that is read from src0.
//   tpig - this thread's position in the grid; used directly as the element index.
// Each thread handles exactly one element and performs no bounds check, so the
// dispatch must not launch more threads than there are elements.
// NOTE(review): the host-side encoder appears to dispatch ggml_nelements(dst)
// single-thread threadgroups for this op - confirm if the dispatch shape changes.
| kernel void kernel_neg( | ||||
|         device const float * src0, | ||||
|         device       float * dst, | ||||
|         uint tpig[[thread_position_in_grid]]) { | ||||
|     device const float & x = src0[tpig]; | ||||
|     dst[tpig] = -x; | ||||
| } | ||||
| 
 | ||||
| constant float GELU_COEF_A     = 0.044715f; | ||||
| constant float GELU_QUICK_COEF = -1.702f; | ||||
| constant float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f; | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue