ggml : add ALiBi support for ggml_soft_max_ext (#5488)

* ggml : avoid recomputing alibi slopes (CPU)

* llama : reuse hparams.f_max_alibi_bias in all cases

ggml-ci

* ggml : support alibi bias in ggml_soft_max_ext (CPU + Metal)

ggml-ci

* ggml : handle all SRCs (do not break on first null)

ggml-ci

* tests : do not use slope for large soft_max

accumulates too much error

ggml-ci

* ggml : alternative ALiBi without extra tensor

We compute the slopes in the kernel

ggml-ci

* cuda : add ALiBi support in ggml_soft_max_ext

ggml-ci

* ggml : deprecate ggml_alibi

* ggml : support multi-sequence ALiBi (Metal)

ggml-ci

* cuda : add multi-seq ALiBi + remove F16 soft_max

ggml-ci

* ggml : update deprecation message

* ggml : fix pos ptr when no ALiBi

ggml-ci

* cuda : fix performance (pow -> powf)

* cuda : precompute ALiBi constants

* metal : pre-compute ALiBi slopes

ggml-ci

* llama : init kq_pos only if needed

ggml-ci

* test-backend-ops : add null pos test to soft_max

test-backend-ops : replace soft_max tests

ggml-ci

---------

Co-authored-by: slaren <slarengh@gmail.com>
commit 8f1be0d42f (parent 6e4e973b26)
Georgi Gerganov, 2024-02-17 23:04:16 +02:00, committed by GitHub
9 changed files with 348 additions and 357 deletions
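As the log above notes, the slopes are no longer materialized in a separate tensor: each backend's soft_max kernel derives the per-head ALiBi slope from max_bias and adds slope * pos[j] to the logits before the softmax. A minimal C sketch of that slope schedule, assuming the standard geometric ALiBi recurrence (alibi_slope is an illustrative helper, not part of the ggml API):

    #include <math.h>

    // Per-head ALiBi slope as a kernel can derive it from max_bias alone
    static float alibi_slope(int h, int n_head, float max_bias) {
        // number of heads rounded down to a power of two
        const int n_head_log2 = 1 << (int) floorf(log2f((float) n_head));

        // base slopes for the first n_head_log2 heads and for the remainder
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        return h < n_head_log2 ? powf(m0, h + 1)
                               : powf(m1, 2*(h - n_head_log2) + 1);
    }

    // each logit of head h is then biased before the softmax, roughly:
    //   wp[j] = scale*src[j] + alibi_slope(h, n_head, max_bias)*pos[j];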

tests/test-backend-ops.cpp:

@@ -1085,24 +1085,32 @@ struct test_diag_mask_inf : public test_case {
 struct test_soft_max : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    const float scale;
     const bool mask;
+    const float scale;
+    const float max_bias;
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne, scale, mask);
+        return VARS_TO_STR5(type, ne, mask, scale, max_bias);
     }
 
     test_soft_max(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            bool mask = false,
             float scale = 1.0f,
-            bool mask = false)
-        : type(type), ne(ne), scale(scale), mask(mask) {}
+            float max_bias = 0.0f)
+        : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * b = nullptr;
-        if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); }
-        ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale);
+        ggml_tensor * mask = nullptr;
+        if (this->mask) {
+            mask = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
+        }
+        ggml_tensor * pos = nullptr;
+        if (max_bias > 0.0f) {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
+        }
+        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
         return out;
     }
 };
@@ -1147,30 +1155,6 @@ struct test_rope : public test_case {
     }
 };
 
-// GGML_OP_ALIBI
-struct test_alibi : public test_case {
-    const ggml_type type;
-    const std::array<int64_t, 4> ne;
-    int n_past;
-    int n_head;
-    float bias_max;
-
-    std::string vars() override {
-        return VARS_TO_STR5(type, ne, n_past, n_head, bias_max);
-    }
-
-    test_alibi(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10},
-            int n_past = 512, int n_head = 10, float bias_max = 0.5f)
-        : type(type), ne(ne), n_past(n_past), n_head(n_head), bias_max(bias_max) {}
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * out = ggml_alibi(ctx, a, n_past, n_head, bias_max);
-        return out;
-    }
-};
-
 // GGML_OP_POOL2D
 struct test_pool2d : public test_case {
     enum ggml_op_pool pool_type;
@@ -1488,7 +1472,7 @@ struct test_moe : public test_case {
         ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
 
         ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
-        ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd));
+        ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, nullptr, 1.0f/sqrtf(n_embd), 0.0f);
 
         // select experts
         ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
@@ -1617,7 +1601,6 @@ public:
         ggml_cpy(ctx, v_cur_t, v_cache_view);
     }
 
-    // if max_alibi_bias > 0 then apply ALiBi
     struct ggml_tensor * llm_build_kqv(
             struct ggml_context * ctx,
             struct ggml_tensor * k_l,
@@ -1636,7 +1619,7 @@ public:
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
 
-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale, 0.0f);
 
         // split cached v into n_head heads
         struct ggml_tensor * v =
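For graph-building code this removes the separate ggml_alibi node (now deprecated): positions and the maximum bias are passed straight to the extended softmax, and non-ALiBi models pass a null pos and a zero bias, as the hunk above does. A hedged before/after sketch of a KQ call site (the wrapper and tensor names are placeholders, not code from this patch):

    #include "ggml.h"

    static struct ggml_tensor * kq_softmax(
            struct ggml_context * ctx,
            struct ggml_tensor  * kq,        // raw attention scores
            struct ggml_tensor  * kq_mask,   // optional mask, may be NULL
            struct ggml_tensor  * kq_pos,    // F32 positions, NULL if no ALiBi
            float kq_scale, float max_alibi_bias) {
        // before (deprecated path), roughly:
        //   kq = ggml_scale(ctx, kq, kq_scale);
        //   kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
        //   kq = ggml_soft_max(ctx, kq);

        // after: mask, positions, scale and bias handled by one fused op
        return ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, max_alibi_bias);
    }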
@@ -2083,6 +2066,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 1}, 5));
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
 
+#if 0
     std::uniform_int_distribution<> dist_ne1(1, 50);
     int exponent = 1;
     while (exponent < (1 << 17)) {
@@ -2091,14 +2075,29 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         for (int n = 0; n < 10; ++n) {
             int64_t ne0 = dist_ne0(rng);
             int64_t ne1 = dist_ne1(rng);
-            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}));
+            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f));
         }
 
         exponent <<= 1;
     }
+#endif
+    for (bool mask : {false, true}) {
+        for (float max_bias : {0.0f, 8.0f}) {
+            for (float scale : {1.0f, 0.1f}) {
+                for (int64_t ne0 : {16, 1024}) {
+                    for (int64_t ne1 : {16, 1024}) {
+                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias));
+                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
+                    }
+                }
+            }
+        }
+    }
 
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, 0.1f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, 0.1f, true));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
 
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
         test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
@@ -2113,7 +2112,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512)); // neox (phi-2)
     }
 
-    test_cases.emplace_back(new test_alibi());
     test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
     test_cases.emplace_back(new test_concat(GGML_TYPE_I32));