ggml : ggml_flash_attn_ext() support ALiBi (Metal)

parent 166e60bf9b
commit 97c27f59f6

4 changed files with 105 additions and 47 deletions:
ggml-metal.m, ggml-metal.metal, llama.cpp, tests/test-backend-ops.cpp
ggml-metal.m

@@ -1390,8 +1390,8 @@ static enum ggml_status ggml_metal_graph_compute(
         const int64_t nrows_x = ggml_nrows(src0);
         const int64_t nrows_y = src0->ne[1];

-        const uint32_t n_head_kv   = nrows_x/nrows_y;
-        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+        const uint32_t n_head      = nrows_x/nrows_y;
+        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

         const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
         const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
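
Note: m0 and m1 are the two ALiBi slope bases. Spelled out (notation mine, read off the code): m0 = 2^(-max_bias / n_head_log2) and m1 = 2^(-(max_bias/2) / n_head_log2); the slope of head h is then m0^(h+1) for h < n_head_log2 and m1^(2*(h - n_head_log2) + 1) otherwise, which is exactly the base/exponent selection the kernels further down perform.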
@@ -2513,7 +2513,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");

                 const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
-                const int64_t ne31 = src3 ? src3->ne[1] : 0;
+              //const int64_t ne31 = src3 ? src3->ne[1] : 0;
                 const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
                 const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
@@ -2525,7 +2525,16 @@ static enum ggml_status ggml_metal_graph_compute(
                 const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);

                 float scale;
-                memcpy(&scale, dst->op_params, sizeof(float));
+                float max_bias;
+
+                memcpy(&scale,    ((int32_t *) dst->op_params) + 0, sizeof(scale));
+                memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+                const uint32_t n_head      = src0->ne[2];
+                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+                const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

                 id<MTLComputePipelineState> pipeline = nil;
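
For reference, a minimal self-contained C++ sketch of the per-head slopes these constants produce; it mirrors the base/exponent selection in the Metal kernels below (the driver values max_bias = 8 and n_head = 32 are mine, for illustration):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Sketch: per-head ALiBi slope, matching the m0/m1/n_head_log2 setup above
    // and the base/exph selection in kernel_flash_attn_ext_f16 further down.
    int main() {
        const float    max_bias    = 8.0f; // example value
        const uint32_t n_head      = 32;   // example value
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        for (uint32_t h = 0; h < n_head; ++h) {
            const float base = h < n_head_log2 ? m0 : m1;
            const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
            printf("head %2u: slope = %.6f\n", h, powf(base, (float) exph));
        }
        return 0;
    }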
@@ -2583,13 +2592,16 @@ static enum ggml_status ggml_metal_graph_compute(
                [encoder setBytes:&nb11        length:sizeof(uint64_t) atIndex:18];
                [encoder setBytes:&nb12        length:sizeof(uint64_t) atIndex:19];
                [encoder setBytes:&nb13        length:sizeof(uint64_t) atIndex:20];
-               [encoder setBytes:&ne31        length:sizeof( int64_t) atIndex:21];
-               [encoder setBytes:&nb31        length:sizeof(uint64_t) atIndex:22];
-               [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:23];
-               [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:24];
-               [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:25];
-               [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:26];
-               [encoder setBytes:&scale       length:sizeof(   float) atIndex:27];
+               [encoder setBytes:&nb31        length:sizeof(uint64_t) atIndex:21];
+               [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:22];
+               [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:23];
+               [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:24];
+               [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:25];
+               [encoder setBytes:&scale       length:sizeof(   float) atIndex:26];
+               [encoder setBytes:&max_bias    length:sizeof(   float) atIndex:27];
+               [encoder setBytes:&m0          length:sizeof(m0)       atIndex:28];
+               [encoder setBytes:&m1          length:sizeof(m1)       atIndex:29];
+               [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30];

                if (!use_vec_kernel) {
                    // half8x8 kernel
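Note: dropping ne31 from the argument table shifts every later buffer index down by one, and the four new ALiBi arguments (max_bias, m0, m1, n_head_log2) take indices 27-30; these have to line up with the kernel parameter order in ggml-metal.metal below.
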
ggml-metal.metal

@@ -2058,13 +2058,16 @@ typedef void (flash_attn_ext_f16_t)(
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant  uint64_t & nb13,
-       constant   int64_t & ne31,
        constant  uint64_t & nb31,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant     float & scale,
+       constant     float & max_bias,
+       constant     float & m0,
+       constant     float & m1,
+       constant  uint32_t & n_head_log2,
        threadgroup   half * shared,
        uint3  tgpig[[threadgroup_position_in_grid]],
        uint3  tpitg[[thread_position_in_threadgroup]],
@@ -2096,13 +2099,16 @@ kernel void kernel_flash_attn_ext_f16(
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant  uint64_t & nb13,
-       constant   int64_t & ne31,
        constant  uint64_t & nb31,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant     float & scale,
+       constant     float & max_bias,
+       constant     float & m0,
+       constant     float & m1,
+       constant  uint32_t & n_head_log2,
        threadgroup   half * shared [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        uint3  tpitg[[thread_position_in_threadgroup]],
@@ -2199,6 +2205,19 @@ kernel void kernel_flash_attn_ext_f16(
        // prepare diagonal scale matrix
        simdgroup_float8x8 mscale(scale);

+       // prepare diagonal slope matrix
+       simdgroup_float8x8 mslope(1.0f);
+
+       // ALiBi
+       if (max_bias > 0.0f) {
+           const short h = iq2;
+
+           const float base = h < n_head_log2 ? m0 : m1;
+           const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+           mslope = simdgroup_float8x8(pow(base, exph));
+       }
+
        // loop over the KV cache
        // each simdgroup handles blocks of Q rows and C columns
        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
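
The single-scalar simdgroup_float8x8 constructor builds a diagonal matrix (hence the "prepare diagonal slope matrix" comment), so multiplying the mask tile by mslope scales every mask element by the per-head slope. A plain C++ illustration of that effect, under the assumption that simdgroup_multiply(mm, mslope, mm) behaves like diag(slope) * mm (tile values here are arbitrary):

    #include <cstdio>

    // Illustration: left-multiplying an 8x8 tile by diag(slope) scales every
    // element by slope, matching simdgroup_multiply(mm, mslope, mm) when
    // mslope was constructed as simdgroup_float8x8(slope).
    int main() {
        const int   N     = 8;
        const float slope = 0.0625f;   // arbitrary example slope

        float mask[8][8] = {};
        mask[0][1] = -1.0f;            // e.g. a causal-mask distance entry

        for (int i = 0; i < N; ++i) {
            for (int j = 0; j < N; ++j) {
                mask[i][j] *= slope;   // diag(slope) * mask, element by element
            }
        }
        printf("mask[0][1] = %f\n", mask[0][1]); // -0.0625
        return 0;
    }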
@@ -2221,9 +2240,10 @@ kernel void kernel_flash_attn_ext_f16(
                    simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
                }

-               // mqk = mqk*scale + mask
+               // mqk = mqk*scale + mask*slope
                simdgroup_half8x8 mm;
                simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+               simdgroup_multiply(mm, mslope, mm);
                simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);

                simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
@@ -2414,13 +2434,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant  uint64_t & nb13,
-       constant   int64_t & ne31,
        constant  uint64_t & nb31,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant     float & scale,
+       constant     float & max_bias,
+       constant     float & m0,
+       constant     float & m1,
+       constant  uint32_t & n_head_log2,
        threadgroup   half * shared [[threadgroup(0)]],
        uint3  tgpig[[threadgroup_position_in_grid]],
        uint3  tpitg[[thread_position_in_threadgroup]],
@@ -2439,6 +2462,18 @@ kernel void kernel_flash_attn_ext_vec_f16(

    const short T = D + 2*nsg*SH; // shared memory size per query in (half)

+   float slope = 1.0f;
+
+   // ALiBi
+   if (max_bias > 0.0f) {
+       const short h = iq2;
+
+       const float base = h < n_head_log2 ? m0 : m1;
+       const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+       slope = pow(base, exp);
+   }
+
  //threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data
    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
    threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
@@ -2545,10 +2580,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
                mqk += simd_shuffle_down(mqk, 2);
                mqk += simd_shuffle_down(mqk, 1);

-               // mqk = mqk*scale + mask
+               // mqk = mqk*scale + mask*slope
                if (tiisg == 0) {
                    float4 mm = (float4) mp4[ic/4 + cc];
-                   mqk = mqk*scale + mm;
+                   mqk = mqk*scale + mm*slope;

                    ss4[cc] = mqk;
                }
@@ -2782,7 +2817,8 @@ kernel void kernel_cpy_f32_f16(
    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

-       dst_data[i00] = src[0];
+       // TODO: is there a better way to handle -INFINITY?
+       dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
    }
}
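
A side note on the kernel_cpy_f32_f16 change: f16 can represent -inf, so the copy itself is not the obstacle; the clamp presumably keeps the mask finite so that the later mask*slope scaling and simdgroup multiplies cannot turn infinities into NaNs (the TODO in the diff suggests this is a pragmatic workaround). A C++ sketch of the same clamp, where MAXHALF is Metal's largest finite half value (65504) and the helper name is mine:

    #include <cmath>
    #include <cstdio>

    // Sketch of the clamp in kernel_cpy_f32_f16 above: map -INFINITY to the
    // most negative finite half so the mask stays finite after it is scaled.
    static float clamp_neg_inf_for_f16(float x) { // hypothetical helper name
        const float MAXHALF = 65504.0f; // largest finite f16 value
        return x == -INFINITY ? -MAXHALF : x;
    }

    int main() {
        printf("%g\n", clamp_neg_inf_for_f16(-INFINITY)); // -65504
        printf("%g\n", clamp_neg_inf_for_f16(-2.5f));     // -2.5
        return 0;
    }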
llama.cpp

@@ -10985,6 +10985,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                            data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                        }
                    }
+
+                   for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                       for (int j = 0; j < n_kv; ++j) {
+                           data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                       }
+                   }
                }
            } else {
                // when using kv cache, the mask needs to match the kv cache size
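
The new loop masks the padding rows: GGML_PAD rounds n_tokens up to a multiple of GGML_KQ_MASK_PAD, and every row in the padded tail is filled with -INFINITY so those query rows contribute nothing. A small sketch of the rounding, where pad_to is my stand-in for GGML_PAD and 32 is an assumed pad value for illustration:

    #include <cstdio>

    // Stand-in for GGML_PAD: round x up to the next multiple of n (n > 0).
    static int pad_to(int x, int n) {
        return ((x + n - 1) / n) * n;
    }

    int main() {
        const int n_tokens = 5;
        const int pad      = 32; // assumed GGML_KQ_MASK_PAD value, for illustration
        // mask rows n_tokens .. pad_to(n_tokens, pad) - 1 get -INFINITY
        printf("padded height: %d, masked rows: %d..%d\n",
               pad_to(n_tokens, pad), n_tokens, pad_to(n_tokens, pad) - 1);
        return 0;
    }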
tests/test-backend-ops.cpp

@@ -1486,23 +1486,25 @@ struct test_flash_attn_ext : public test_case {
    const int64_t kv; // kv size
    const int64_t nb; // batch size

+   const float max_bias; // ALiBi
+
    std::string vars() override {
-       return VARS_TO_STR4(hs, nh, kv, nb);
+       return VARS_TO_STR5(hs, nh, kv, nb, max_bias);
    }

    double max_nmse_err() override {
        return 5e-4;
    }

-   test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
-       : hs(hs), nh(nh), kv(kv), nb(nb) {}
+   test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f)
+       : hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
        ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
        ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
        ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
-       ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
+       ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias);
        return out;
    }
};
@@ -2176,10 +2178,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
#else
    for (int hs : { 64, 80, 128, 256, }) {
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+       for (float max_bias : {0.0f, 8.0f}) {
            for (int nh : { 32, }) {
                for (int kv : { 512, 1024, }) {
                    for (int nb : { 1, 2, 4, 8, }) {
-                       test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
+                       test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias));
                    }
                }
            }
+       }
    }
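
Note: the extra max_bias dimension doubles the sweep; on the non-HIP path that is 4 head sizes x 2 max_bias values x 1 head count x 2 kv sizes x 4 batch sizes = 64 flash-attention cases per backend.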