diff --git a/ggml-metal.m b/ggml-metal.m index 1c67334d1..09df3a817 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1194,6 +1194,14 @@ static bool ggml_metal_graph_compute( const float scale = ((float *) dst->op_params)[0]; const float max_bias = ((float *) dst->op_params)[1]; + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + const uint32_t n_head_kv = nrows_x/nrows_y; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { @@ -1212,6 +1220,9 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6]; [encoder setBytes:&scale length:sizeof(scale) atIndex:7]; [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8]; + [encoder setBytes:&m0 length:sizeof(m0) atIndex:9]; + [encoder setBytes:&m1 length:sizeof(m1) atIndex:10]; + [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; diff --git a/ggml-metal.metal b/ggml-metal.metal index c5da88e14..09ebcc9e3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -358,6 +358,9 @@ kernel void kernel_soft_max( constant int64_t & ne02, constant float & scale, constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], @@ -377,15 +380,12 @@ kernel void kernel_soft_max( // ALiBi if (max_bias > 0.0f) { - const uint32_t n_head_kv = ne02; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv)); - - const float m0 = pow(2.0f, -(max_bias ) / n_head_log2); - const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2); - const int64_t h = i02; - slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1); + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); } // parallel max @@ -462,6 +462,9 @@ kernel void kernel_soft_max_4( constant int64_t & ne02, constant float & scale, constant float & max_bias, + constant float & m0, + constant float & m1, + constant uint32_t & n_head_log2, threadgroup float * buf [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], @@ -480,15 +483,12 @@ kernel void kernel_soft_max_4( float slope = 0.0f; if (max_bias > 0.0f) { - const uint32_t n_head_kv = ne02; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv)); - - const float m0 = pow(2.0f, -(max_bias ) / n_head_log2); - const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2); - const int64_t h = i02; - slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1); + const float base = h < n_head_log2 ? m0 : m1; + const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); } // parallel max