ggml : full ALiBi support (#7192)
* ggml : full ALiBi support * ggml : update ggml_soft_max_ext() CUDA, SYCL * ggml : ggml_flash_attn_ext() support ALiBi (CPU) * ggml : ggml_flash_attn_ext() support ALiBi (Metal) * ggml : fix warning * ggml : ggml_flash_attn_ext() support ALiBi (CUDA) ggml-ci * ggml : fix assert message * vulkan : add dev notes * ggml : require mask when using ALiBi ggml-ci * convert : fix convert for refact models
This commit is contained in:
parent
e849648888
commit
9cb317f77e
16 changed files with 350 additions and 825 deletions
|
@ -1111,11 +1111,7 @@ struct test_soft_max : public test_case {
|
|||
if (this->mask) {
|
||||
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
|
||||
}
|
||||
ggml_tensor * pos = nullptr;
|
||||
if (max_bias > 0.0f) {
|
||||
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
|
||||
}
|
||||
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
|
||||
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
@ -1490,23 +1486,25 @@ struct test_flash_attn_ext : public test_case {
|
|||
const int64_t kv; // kv size
|
||||
const int64_t nb; // batch size
|
||||
|
||||
const float max_bias; // ALiBi
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(hs, nh, kv, nb);
|
||||
return VARS_TO_STR5(hs, nh, kv, nb, max_bias);
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 5e-4;
|
||||
}
|
||||
|
||||
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
|
||||
: hs(hs), nh(nh), kv(kv), nb(nb) {}
|
||||
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f)
|
||||
: hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
|
||||
ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
|
||||
ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
|
||||
ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
|
||||
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
|
||||
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
@ -1611,7 +1609,7 @@ public:
|
|||
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||
|
||||
kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale, 0.0f);
|
||||
kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);
|
||||
|
||||
// split cached v into n_head heads
|
||||
struct ggml_tensor * v =
|
||||
|
@ -2128,6 +2126,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
#endif
|
||||
for (bool mask : {false, true}) {
|
||||
for (float max_bias : {0.0f, 8.0f}) {
|
||||
if (!mask && max_bias > 0.0f) continue;
|
||||
for (float scale : {1.0f, 0.1f}) {
|
||||
for (int64_t ne0 : {16, 1024}) {
|
||||
for (int64_t ne1 : {16, 1024}) {
|
||||
|
@ -2141,7 +2140,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
|
||||
|
||||
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
|
@ -2180,10 +2178,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
#else
|
||||
for (int hs : { 64, 80, 128, 256, }) {
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
for (int nh : { 32, }) {
|
||||
for (int kv : { 512, 1024, }) {
|
||||
for (int nb : { 1, 2, 4, 8, }) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
|
||||
for (float max_bias : {0.0f, 8.0f}) {
|
||||
for (int nh : { 32, }) {
|
||||
for (int kv : { 512, 1024, }) {
|
||||
for (int nb : { 1, 2, 4, 8, }) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue