CPU/CUDA: fix (GQA) mul mat back, add CUDA support (#11380)

This commit is contained in:
Johannes Gäßler 2025-01-24 12:38:31 +01:00 committed by GitHub
parent 1af6945eb0
commit 8137b4bb2b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 156 additions and 61 deletions

View file

@ -1302,6 +1302,59 @@ struct test_repeat : public test_case {
}
};
// GGML_OP_REPEAT_BACK
struct test_repeat_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const std::array<int, 4> nr;
const bool v; // whether src is a noncontiguous view
std::string vars() override {
return VARS_TO_STR4(type, ne, nr, v);
}
size_t op_size(ggml_tensor * t) override {
return ggml_nbytes(t) * 2;
}
test_repeat_back(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {8, 6, 4, 2},
std::array<int, 4> nr = {2, 2, 2, 2},
bool v = false)
: type(type), ne(ne), nr(nr), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
ggml_set_name(src, "src");
if (v) {
GGML_ASSERT(ne[0] % 2 == 0);
GGML_ASSERT(ne[1] % 2 == 0);
GGML_ASSERT(ne[2] % 2 == 0);
GGML_ASSERT(ne[3] % 2 == 0);
GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);
GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);
GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);
GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);
const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;
const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;
const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;
const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2;
src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);
}
ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(target, "target");
ggml_tensor * out = ggml_repeat_back(ctx, src, target);
ggml_set_name(out, "out");
return out;
}
};
// GGML_OP_DUP
struct test_dup : public test_case {
const ggml_type type;
@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case {
return 5e-4;
}
int64_t grad_nmax() override {
return 20000;
}
uint64_t op_flops(ggml_tensor * t) override {
GGML_UNUSED(t);
return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case {
a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
ggml_set_param(ctx, a);
ggml_set_param(ctx, b);
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(ctx, a);
}
ggml_set_param(ctx, b);
}
ggml_set_name(a, "a");
ggml_set_name(b, "b");
@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case {
} else {
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
ggml_set_param(ctx, a);
ggml_set_param(ctx, b);
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
ggml_set_param(ctx, a);
}
ggml_set_param(ctx, b);
}
ggml_set_name(a, "a");
ggml_set_name(b, "b");
}
@ -3798,6 +3863,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
}
for (bool view : {false, true}) {
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
}
test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
@ -3919,21 +3994,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
// test cases without permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
// test cases with permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));