CUDA: non-contiguous (RMS) norm support (#11659)

* CUDA: non-contiguous (RMS) norm support

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Johannes Gäßler 2025-02-04 22:21:42 +01:00 committed by GitHub
parent 3ec9fd4b77
commit fd08255d0d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 97 additions and 47 deletions

View file

@ -1674,21 +1674,28 @@ struct test_silu_back : public test_case {
struct test_norm : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float eps;
const bool v; // whether a is a non-contiguous view
const float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
return VARS_TO_STR4(type, ne, v, eps);
}
test_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
bool v = false,
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}
: type(type), ne(ne), v(v), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
if (v) {
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view of a");
}
ggml_tensor * out = ggml_norm(ctx, a, eps);
ggml_set_name(out, "out");
@ -1700,22 +1707,29 @@ struct test_norm : public test_case {
struct test_rms_norm : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float eps;
const bool v; // whether a is a non-contiguous view
const float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
return VARS_TO_STR4(type, ne, v, eps);
}
test_rms_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
bool v = false,
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}
: type(type), ne(ne), v(v), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_name(a, "a");
if (v) {
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view of a");
}
ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
ggml_set_name(out, "out");
@ -1741,7 +1755,7 @@ struct test_rms_norm : public test_case {
struct test_rms_norm_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float eps;
const float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
@ -2919,7 +2933,7 @@ struct test_group_norm : public test_case {
const float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, num_groups);
return VARS_TO_STR4(type, ne, num_groups, eps);
}
test_group_norm(ggml_type type = GGML_TYPE_F32,
@ -3964,9 +3978,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_scale());
test_cases.emplace_back(new test_silu_back());
for (float eps : {0.0f, 1e-7f, 1e-4f, 1e-1f}) {
test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_rms_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
for (bool v : {false, true}) {
test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
}
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}