metal : works with ne00 % 4 == 0

This commit is contained in:
Georgi Gerganov 2024-02-08 13:26:50 +02:00
parent e68e32548f
commit 845876d012
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 31 additions and 18 deletions

View file

@ -2076,16 +2076,20 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
}
#else
for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 1, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 2, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 3, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 4, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 5, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 6, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 7, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 8, 4096, { 1, 1}, {1, 1}));
for (int r0 = 0; r0 < 32; ++r0) {
for (int c0 = 0; c0 < 4096; c0 += 512) {
for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 1, 64 + c0, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 2, 64 + c0, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 3, 64 + c0, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 4, 64 + c0, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 5, 64 + c0, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 6, 64 + c0, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 7, 64 + c0, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 64 + r0, 8, 64 + c0, { 1, 1}, {1, 1}));
}
}
}
}
#endif