cuda : fix dmmv cols requirement to 2*GGML_CUDA_DMMV_X (#8800)

* cuda : fix dmmv cols requirement to 2*GGML_CUDA_DMMV_X

* update asserts

* only use dmmv for supported types

* add test
This commit is contained in:
slaren 2024-08-01 15:26:22 +02:00 committed by GitHub
parent c8a0090922
commit 7a11eb3a26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 22 additions and 11 deletions

View file

@@ -804,8 +804,7 @@ struct test_cpy : public test_case {
test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 1},
-std::array<int64_t, 4> permute = {0, 0, 0, 0},
-bool _dst_use_permute = false)
+std::array<int64_t, 4> permute = {0, 0, 0, 0})
: type_src(type_src), type_dst(type_dst), ne(ne), permute(permute),
_src_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
@@ -2269,6 +2268,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
for (ggml_type type_a : other_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, ggml_blck_size(type_a), { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
}
}