metal : pre-compute ALiBi slopes

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-02-16 10:17:59 +02:00
parent ac91033ccb
commit 833490b16f
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 25 additions and 14 deletions

View file

@ -1194,6 +1194,14 @@ static bool ggml_metal_graph_compute(
const float scale = ((float *) dst->op_params)[0];
const float max_bias = ((float *) dst->op_params)[1];
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
const uint32_t n_head_kv = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) {
@ -1212,6 +1220,9 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];

View file

@ -358,6 +358,9 @@ kernel void kernel_soft_max(
constant int64_t & ne02,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
@ -377,15 +380,12 @@ kernel void kernel_soft_max(
// ALiBi
if (max_bias > 0.0f) {
const uint32_t n_head_kv = ne02;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
const float m0 = pow(2.0f, -(max_bias ) / n_head_log2);
const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
const int64_t h = i02;
slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max
@ -462,6 +462,9 @@ kernel void kernel_soft_max_4(
constant int64_t & ne02,
constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
@ -480,15 +483,12 @@ kernel void kernel_soft_max_4(
float slope = 0.0f;
if (max_bias > 0.0f) {
const uint32_t n_head_kv = ne02;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
const float m0 = pow(2.0f, -(max_bias ) / n_head_log2);
const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
const int64_t h = i02;
slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
// parallel max