ggml : ggml_flash_attn_ext() support ALiBi (Metal)

This commit is contained in:
Georgi Gerganov 2024-05-10 13:51:00 +03:00
parent 166e60bf9b
commit 97c27f59f6
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 105 additions and 47 deletions

View file

@ -1390,8 +1390,8 @@ static enum ggml_status ggml_metal_graph_compute(
const int64_t nrows_x = ggml_nrows(src0); const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1]; const int64_t nrows_y = src0->ne[1];
const uint32_t n_head_kv = nrows_x/nrows_y; const uint32_t n_head = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@ -2513,7 +2513,7 @@ static enum ggml_status ggml_metal_graph_compute(
"the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
const int64_t ne31 = src3 ? src3->ne[1] : 0; //const int64_t ne31 = src3 ? src3->ne[1] : 0;
const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
@ -2525,7 +2525,16 @@ static enum ggml_status ggml_metal_graph_compute(
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
float scale; float scale;
memcpy(&scale, dst->op_params, sizeof(float)); float max_bias;
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
id<MTLComputePipelineState> pipeline = nil; id<MTLComputePipelineState> pipeline = nil;
@ -2562,34 +2571,37 @@ static enum ggml_status ggml_metal_graph_compute(
} }
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
[encoder setBuffer:id_dst offset:offs_dst atIndex:4]; [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
[encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
[encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
[encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
[encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
[encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21];
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:22];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:23];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:24];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:25];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:26];
[encoder setBytes:&scale length:sizeof( float) atIndex:27]; [encoder setBytes:&max_bias length:sizeof( float) atIndex:27];
[encoder setBytes:&m0 length:sizeof(m0) atIndex:28];
[encoder setBytes:&m1 length:sizeof(m1) atIndex:29];
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30];
if (!use_vec_kernel) { if (!use_vec_kernel) {
// half8x8 kernel // half8x8 kernel

View file

@ -2058,13 +2058,16 @@ typedef void (flash_attn_ext_f16_t)(
constant uint64_t & nb11, constant uint64_t & nb11,
constant uint64_t & nb12, constant uint64_t & nb12,
constant uint64_t & nb13, constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31, constant uint64_t & nb31,
constant int64_t & ne0, constant int64_t & ne0,
constant int64_t & ne1, constant int64_t & ne1,
constant int64_t & ne2, constant int64_t & ne2,
constant int64_t & ne3, constant int64_t & ne3,
constant float & scale, constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup half * shared, threadgroup half * shared,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]], uint3 tpitg[[thread_position_in_threadgroup]],
@ -2096,13 +2099,16 @@ kernel void kernel_flash_attn_ext_f16(
constant uint64_t & nb11, constant uint64_t & nb11,
constant uint64_t & nb12, constant uint64_t & nb12,
constant uint64_t & nb13, constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31, constant uint64_t & nb31,
constant int64_t & ne0, constant int64_t & ne0,
constant int64_t & ne1, constant int64_t & ne1,
constant int64_t & ne2, constant int64_t & ne2,
constant int64_t & ne3, constant int64_t & ne3,
constant float & scale, constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup half * shared [[threadgroup(0)]], threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]], uint3 tpitg[[thread_position_in_threadgroup]],
@ -2199,6 +2205,19 @@ kernel void kernel_flash_attn_ext_f16(
// prepare diagonal scale matrix // prepare diagonal scale matrix
simdgroup_float8x8 mscale(scale); simdgroup_float8x8 mscale(scale);
// prepare diagonal slope matrix
simdgroup_float8x8 mslope(1.0f);
// ALiBi
if (max_bias > 0.0f) {
const short h = iq2;
const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
mslope = simdgroup_float8x8(pow(base, exph));
}
// loop over the KV cache // loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns // each simdgroup handles blocks of Q rows and C columns
for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@ -2221,9 +2240,10 @@ kernel void kernel_flash_attn_ext_f16(
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
} }
// mqk = mqk*scale + mask // mqk = mqk*scale + mask*slope
simdgroup_half8x8 mm; simdgroup_half8x8 mm;
simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false); simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
simdgroup_multiply(mm, mslope, mm);
simdgroup_multiply_accumulate(mqk, mqk, mscale, mm); simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
simdgroup_store(mqk, ss + 8*cc, TF, 0, false); simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
@ -2414,13 +2434,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
constant uint64_t & nb11, constant uint64_t & nb11,
constant uint64_t & nb12, constant uint64_t & nb12,
constant uint64_t & nb13, constant uint64_t & nb13,
constant int64_t & ne31,
constant uint64_t & nb31, constant uint64_t & nb31,
constant int64_t & ne0, constant int64_t & ne0,
constant int64_t & ne1, constant int64_t & ne1,
constant int64_t & ne2, constant int64_t & ne2,
constant int64_t & ne3, constant int64_t & ne3,
constant float & scale, constant float & scale,
constant float & max_bias,
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
threadgroup half * shared [[threadgroup(0)]], threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]], uint3 tpitg[[thread_position_in_threadgroup]],
@ -2439,6 +2462,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short T = D + 2*nsg*SH; // shared memory size per query in (half) const short T = D + 2*nsg*SH; // shared memory size per query in (half)
float slope = 1.0f;
// ALiBi
if (max_bias > 0.0f) {
const short h = iq2;
const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = pow(base, exp);
}
//threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
@ -2545,10 +2580,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
mqk += simd_shuffle_down(mqk, 2); mqk += simd_shuffle_down(mqk, 2);
mqk += simd_shuffle_down(mqk, 1); mqk += simd_shuffle_down(mqk, 1);
// mqk = mqk*scale + mask // mqk = mqk*scale + mask*slope
if (tiisg == 0) { if (tiisg == 0) {
float4 mm = (float4) mp4[ic/4 + cc]; float4 mm = (float4) mp4[ic/4 + cc];
mqk = mqk*scale + mm; mqk = mqk*scale + mm*slope;
ss4[cc] = mqk; ss4[cc] = mqk;
} }
@ -2782,7 +2817,8 @@ kernel void kernel_cpy_f32_f16(
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0]; // TODO: is there a better way to handle -INFINITY?
dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
} }
} }

View file

@ -10985,6 +10985,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = f; data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
} }
} }
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_kv; ++j) {
data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
}
}
} }
} else { } else {
// when using kv cache, the mask needs to match the kv cache size // when using kv cache, the mask needs to match the kv cache size

View file

@ -1486,23 +1486,25 @@ struct test_flash_attn_ext : public test_case {
const int64_t kv; // kv size const int64_t kv; // kv size
const int64_t nb; // batch size const int64_t nb; // batch size
const float max_bias; // ALiBi
std::string vars() override { std::string vars() override {
return VARS_TO_STR4(hs, nh, kv, nb); return VARS_TO_STR5(hs, nh, kv, nb, max_bias);
} }
double max_nmse_err() override { double max_nmse_err() override {
return 5e-4; return 5e-4;
} }
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f)
: hs(hs), nh(nh), kv(kv), nb(nb) {} : hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1); ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias);
return out; return out;
} }
}; };
@ -2176,10 +2178,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
#else #else
for (int hs : { 64, 80, 128, 256, }) { for (int hs : { 64, 80, 128, 256, }) {
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
for (int nh : { 32, }) { for (float max_bias : {0.0f, 8.0f}) {
for (int kv : { 512, 1024, }) { for (int nh : { 32, }) {
for (int nb : { 1, 2, 4, 8, }) { for (int kv : { 512, 1024, }) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); for (int nb : { 1, 2, 4, 8, }) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias));
}
} }
} }
} }