diff --git a/ggml-metal.m b/ggml-metal.m index bd25a98df..89b17ce5e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -45,6 +45,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(scale); GGML_METAL_DECL_KERNEL(silu); GGML_METAL_DECL_KERNEL(relu); + GGML_METAL_DECL_KERNEL(gelu); GGML_METAL_DECL_KERNEL(soft_max); GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(get_rows_f16); @@ -135,6 +136,7 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(scale); GGML_METAL_ADD_KERNEL(silu); GGML_METAL_ADD_KERNEL(relu); + GGML_METAL_ADD_KERNEL(gelu); GGML_METAL_ADD_KERNEL(soft_max); GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(get_rows_f16); @@ -420,6 +422,20 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_GELU: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + [encoder setComputePipelineState:ctx->pipeline_gelu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_SOFT_MAX: { if (encoder == nil) { @@ -526,7 +542,7 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne12 == 1); nth0 = 8; - nth1 = 4; + nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; } break; case GGML_TYPE_Q2_K: diff --git a/ggml-metal.metal b/ggml-metal.metal index 43814ed09..745fe8ad3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -81,6 +81,17 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } +constant float GELU_COEF_A = 0.044715f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + float x = src0[tpig]; + dst[tpig] = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + kernel void kernel_soft_max( device const float * src0, device float * dst, @@ -267,6 +278,8 @@ kernel void kernel_mul_mat_q4_0_f32( uint2 tptg[[threads_per_threadgroup]]) { const int nb = ne00/QK4_0; + const int8_t m8 = 8; + const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -276,33 +289,34 @@ kernel void kernel_mul_mat_q4_0_f32( const uint nth = tptg.x*tptg.y; const uint ith = tptg.y*tpitg.x + tpitg.y; - sum[ith] = 0.0f; + const int ix = tpitg.y/4; // 0 or 1 + const int iy = tpitg.y - 4*ix; // 0...3 - for (int i = tpitg.x; i < nb; i += tptg.x) { - device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs; - device const float4 * y0p = (device const float4 *) (y + i*QK4_0); + const int first = 4 * iy; - const float d = (float)((x + i)->d); + float sumf = 0; - const uchar4 x0v = *(x0p + tpitg.y); - const float4 y0v = *(y0p + tpitg.y + 0); - const float4 y1v = *(y0p + tpitg.y + 4); + for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) { - float acc = 0.0f; + const float d = (float)x[i].d; + + device const uint8_t * xl = x[i].qs + first; + device const float * yl = y + i * QK4_0 + first; + + float2 acc = {0.0f, 0.0f}; for (int j = 0; j < 4; ++j) { - const int x0 = x0v[j] & 0x0F; - const int x1 = x0v[j] >> 4; - const float y0 = y0v[j]; - const float y1 = y1v[j]; + acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8); + acc[1] += yl[j+16] * ((int8_t)(xl[j] >> 4) - m8); - acc += (x0 - 8)*y0 + (x1 - 8)*y1; } - sum[ith] += acc*d; + sumf += d * (acc[0] + acc[1]); } + sum[ith] = sumf; + // // Accumulate the sum from all threads in the threadgroup // This version is slightly faster than the commented out one below, @@ -357,6 +371,7 @@ kernel void kernel_mul_mat_f16_f32( uint3 tpig[[thread_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 tptg[[threads_per_threadgroup]]) { + const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; const int64_t im = tgpig.z;