metal : pre-compute ALiBi slopes
ggml-ci
This commit is contained in:
parent
ac91033ccb
commit
833490b16f
2 changed files with 25 additions and 14 deletions
11
ggml-metal.m
11
ggml-metal.m
|
@ -1194,6 +1194,14 @@ static bool ggml_metal_graph_compute(
|
|||
const float scale = ((float *) dst->op_params)[0];
|
||||
const float max_bias = ((float *) dst->op_params)[1];
|
||||
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
const uint32_t n_head_kv = nrows_x/nrows_y;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
if (id_src1) {
|
||||
|
@ -1212,6 +1220,9 @@ static bool ggml_metal_graph_compute(
|
|||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
|
||||
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
|
||||
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
|
||||
[encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
|
||||
[encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
|
||||
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
|
|
|
@ -358,6 +358,9 @@ kernel void kernel_soft_max(
|
|||
constant int64_t & ne02,
|
||||
constant float & scale,
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
|
@ -377,15 +380,12 @@ kernel void kernel_soft_max(
|
|||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const uint32_t n_head_kv = ne02;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
|
||||
|
||||
const float m0 = pow(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const int64_t h = i02;
|
||||
|
||||
slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
|
@ -462,6 +462,9 @@ kernel void kernel_soft_max_4(
|
|||
constant int64_t & ne02,
|
||||
constant float & scale,
|
||||
constant float & max_bias,
|
||||
constant float & m0,
|
||||
constant float & m1,
|
||||
constant uint32_t & n_head_log2,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
|
@ -480,15 +483,12 @@ kernel void kernel_soft_max_4(
|
|||
float slope = 0.0f;
|
||||
|
||||
if (max_bias > 0.0f) {
|
||||
const uint32_t n_head_kv = ne02;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2((float) n_head_kv));
|
||||
|
||||
const float m0 = pow(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const int64_t h = i02;
|
||||
|
||||
slope = h < n_head_log2 ? pow(m0, h + 1) : pow(m1, 2*(h - n_head_log2) + 1);
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = pow(base, exp);
|
||||
}
|
||||
|
||||
// parallel max
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue