ggml : require mask when using ALiBi
ggml-ci
This commit is contained in:
parent
397b1f8f9d
commit
0faf92e74c
2 changed files with 10 additions and 1 deletions
Changed file: ggml.c (+9 lines).
(Note: the commit header above reports 2 changed files; the hunks at @ -2126 and @ -2139 below belong to the second file — a backend test source whose filename header was lost in extraction.)
@@ -5657,6 +5657,10 @@ static struct ggml_tensor * ggml_soft_max_impl(
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }
 
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(mask);
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -6440,6 +6444,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
         float max_bias) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
+
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[2] == 1);
@@ -6449,6 +6454,10 @@ struct ggml_tensor * ggml_flash_attn_ext(
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
     }
 
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(mask);
+    }
+
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
@@ -2126,6 +2126,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 #endif
     for (bool mask : {false, true}) {
         for (float max_bias : {0.0f, 8.0f}) {
+            if (!mask && max_bias > 0.0f) continue;
             for (float scale : {1.0f, 0.1f}) {
                 for (int64_t ne0 : {16, 1024}) {
                     for (int64_t ne1 : {16, 1024}) {
@@ -2139,7 +2140,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
 
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
Loading…
Add table
Add a link
Reference in a new issue