feat: add GGML_UNARY_OP_ARGMAX
Metal kernel (ggml/1019)
* implemented argmax kernel * tpig -> tgpig * change to strides * contiguous assertions * kernel working and tested * argmax simd parallel implementation * added 2 new tests for argmax in test-backend-ops * cosmit * added 3 tests cases for perf eval * add test_argmax in make_test_cases_perf * Update test-backend-ops.cpp Co-authored-by: Diego Devesa <slarengh@gmail.com> --------- Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
parent
667d70d170
commit
efb6ae9630
3 changed files with 92 additions and 6 deletions
|
@ -3460,14 +3460,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
|
||||
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
|
||||
|
||||
test_cases.emplace_back(new test_argmax());
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_count_equal());
|
||||
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
|
||||
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438, 3, 1, 1}));
|
||||
|
||||
for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
|
||||
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));
|
||||
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1}));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue