Merge branch 'master' into gg/flash-attn
This commit is contained in:
commit
02a645e7b7
94 changed files with 12943 additions and 4640 deletions
|
@ -1114,11 +1114,11 @@ struct test_soft_max : public test_case {
|
|||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * mask = nullptr;
|
||||
if (this->mask) {
|
||||
mask = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
|
||||
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]);
|
||||
}
|
||||
ggml_tensor * pos = nullptr;
|
||||
if (max_bias > 0.0f) {
|
||||
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
|
||||
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, ne[0]);
|
||||
}
|
||||
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
|
||||
return out;
|
||||
|
@ -1274,7 +1274,7 @@ struct test_argsort : public test_case {
|
|||
|
||||
test_argsort(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {16, 10, 10, 10},
|
||||
ggml_sort_order order = GGML_SORT_ASC)
|
||||
ggml_sort_order order = GGML_SORT_ORDER_ASC)
|
||||
: type(type), ne(ne), order(order) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
|
@ -1996,8 +1996,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
|
||||
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
|
||||
GGML_TYPE_Q6_K,
|
||||
GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS,
|
||||
GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
|
||||
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S,
|
||||
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
|
||||
};
|
||||
|
||||
// unary ops
|
||||
|
@ -2195,7 +2196,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
|
||||
test_cases.emplace_back(new test_concat(GGML_TYPE_I32));
|
||||
|
||||
for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
|
||||
for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
|
||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue