commit a2e6b9dee1
parent 541600201e
Author: Georgi Gerganov
Date:   2024-05-14 14:00:43 +03:00

    ggml : fa without mask + add asserts

    ggml-ci
5 changed files with 49 additions and 23 deletions

ggml-metal.m

@@ -2513,12 +2513,14 @@ static enum ggml_status ggml_metal_graph_compute(
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ne11 % 32 == 0);
+
                 GGML_ASSERT(src0->type == GGML_TYPE_F32);

-                struct ggml_tensor * src3 = gf->nodes[i]->src[3];
-
                 GGML_ASSERT(ggml_are_same_shape (src1, src2));
-                GGML_ASSERT(src3);
+                GGML_ASSERT(ggml_are_same_stride(src1, src2));
+
+                struct ggml_tensor * src3 = gf->nodes[i]->src[3];

                 size_t offs_src3 = 0;
@@ -2590,7 +2592,11 @@ static enum ggml_status ggml_metal_graph_compute(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                 [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                if (id_src3) {
+                    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                } else {
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
+                }
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:4];
                 [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
                 [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
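
Host-side note: Metal requires a buffer to be bound at every index the kernel declares, so when there is no mask the encoder re-binds src0 at index 3 as a placeholder that is never read; the kernel detects the missing mask by comparing the mask pointer against q (see the kernel changes below). At the graph level the mask is now simply optional. A minimal usage sketch, assuming an initialized ggml_context * ctx and the ggml_flash_attn_ext signature used by the tests in this commit (sizes are illustrative; kv = 512 satisfies the new ne11 % 32 assert):

    // build a flash-attention node with no mask: pass NULL for the mask tensor
    static struct ggml_tensor * build_fa_no_mask(struct ggml_context * ctx) {
        struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128,   8, 32, 1); // head size, batch, heads
        struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 512, 32, 1); // kv = 512, a multiple of 32
        struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 128, 512, 32, 1); // same shape and stride as k

        // without a mask, max_bias (ALiBi) must stay 0.0f - the slopes are applied through the mask
        return ggml_flash_attn_ext(ctx, q, k, v, /*mask =*/ NULL, 1.0f/sqrtf(128.0f), /*max_bias =*/ 0.0f);
    }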

ggml-metal.metal

@@ -2247,11 +2247,16 @@ kernel void kernel_flash_attn_ext_f16(
                         simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
                     }

+                    if (mask != q) {
                         // mqk = mqk*scale + mask*slope
                         simdgroup_half8x8 mm;
                         simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);

                         simdgroup_multiply(mm, mslope, mm);
                         simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
+                    } else {
+                        // mqk = mqk*scale
+                        simdgroup_multiply(mqk, mscale, mqk);
+                    }

                     simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
                 }
@@ -2589,8 +2594,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
                     // mqk = mqk*scale + mask*slope
                     if (tiisg == 0) {
-                        float4 mm = (float4) mp4[ic/4 + cc];
-                        mqk = mqk*scale + mm*slope;
+                        mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);

                         ss4[cc] = mqk;
                     }
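
Both kernel paths key off mask != q: when no mask exists the host bound q's own buffer at the mask index, so pointer equality is the in-kernel signal for "no mask". In scalar form the per-element score update is roughly the following (a reference sketch of the semantics, not the simdgroup code; a NULL mask stands in for mask == q):

    // s = s*scale + slope*mask[j] when a mask is given, s = s*scale otherwise
    static void scale_and_mask(float * s, int n, float scale, float slope, const float * mask) {
        for (int j = 0; j < n; ++j) {
            s[j] = s[j]*scale + (mask ? slope*mask[j] : 0.0f);
        }
    }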

ggml.c

@@ -2822,6 +2822,16 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
         (t0->ne[3] == t1->ne[3] );
 }

+bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        (t0->nb[0] == t1->nb[0] ) &&
+        (t0->nb[1] == t1->nb[1] ) &&
+        (t0->nb[2] == t1->nb[2] ) &&
+        (t0->nb[3] == t1->nb[3] );
+}
+
 // check if t1 can be represented as a repeatition of t0
 static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
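
The new helper backs the GGML_ASSERT(ggml_are_same_stride(src1, src2)) added above: K and V must share a memory layout, not merely a shape. Equal shapes do not imply equal strides; a transposed square view is the classic counterexample (a sketch, assuming an initialized ggml_context * ctx):

    // same shape, different stride: ggml_transpose swaps nb[0] and nb[1]
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4);
    struct ggml_tensor * b = ggml_transpose(ctx, a); // still 4x4, but a non-contiguous view

    GGML_ASSERT( ggml_are_same_shape (a, b)); // holds
    GGML_ASSERT(!ggml_are_same_stride(a, b)); // strides differ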

ggml.h

@@ -767,6 +767,7 @@ extern "C" {
     GGML_API int  ggml_n_dims        (const struct ggml_tensor * tensor); // returns 1 for scalars

     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
+    GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);

     // use this to compute the memory overhead of a tensor
     GGML_API size_t ggml_tensor_overhead(void);

tests/test-backend-ops.cpp

@@ -1487,25 +1487,27 @@ struct test_flash_attn_ext : public test_case {
     const int64_t kv; // kv size
     const int64_t nb; // batch size

+    const bool  mask;     // use mask
     const float max_bias; // ALiBi

     std::string vars() override {
-        return VARS_TO_STR5(hs, nh, kv, nb, max_bias);
+        return VARS_TO_STR6(hs, nh, kv, nb, mask, max_bias);
     }

     double max_nmse_err() override {
         return 5e-4;
     }

-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f)
-        : hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f)
+        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
-        ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias);
+        ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
         return out;
     }
 };
@@ -2175,11 +2177,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name)
     test_cases.emplace_back(new test_leaky_relu());

     for (int hs : { 64, 80, 128, 256, }) {
+        for (bool mask : { true, false } ) {
         for (float max_bias : { 0.0f, 8.0f }) {
+            if (!mask && max_bias > 0.0f) continue;
             for (int nh : { 32, }) {
                 for (int kv : { 512, 1024, }) {
                     for (int nb : { 1, 2, 4, 8, }) {
-                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias));
+                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias));
+                        }
                     }
                 }
             }
         }
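
Since the new mask parameter defaults to true, a bare test_flash_attn_ext() still builds the masked case, and the loop above skips mask=false combined with max_bias > 0.0f because ALiBi slopes are applied through the mask. An unmasked case can also be requested directly, for example (an illustrative line, not part of the commit):

    // (hs, nh, kv, nb, mask, max_bias)
    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8, /*mask =*/ false, /*max_bias =*/ 0.0f));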