CUDA: fix 1D im2col, add tests (ggml/993)

This commit is contained in:
Johannes Gäßler 2024-10-18 09:24:44 +02:00 committed by Georgi Gerganov
parent c19af0acb1
commit 80273a306d
3 changed files with 34 additions and 9 deletions

View file

@ -3308,15 +3308,41 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
// test cases for 1D im2col
// im2col 1D
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
for (int s0 : {1, 3}) {
for (int p0 : {0, 3}) {
for (int d0 : {1, 3}) {
test_cases.emplace_back(new test_im2col(
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1},
s0, 0, p0, 0, d0, 0, false));
}
}
}
// test cases for 2D im2col
// im2col 2D
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
for (int s0 : {1, 3}) {
for (int s1 : {1, 3}) {
for (int p0 : {0, 3}) {
for (int p1 : {0, 3}) {
for (int d0 : {1, 3}) {
for (int d1 : {1, 3}) {
test_cases.emplace_back(new test_im2col(
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2},
s0, s1, p0, p1, d0, d1, true));
}
}
}
}
}
}
// extra tests for im2col 2D
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true));