diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ee1ee61ae..592656048 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1645,18 +1645,26 @@ struct test_leaky_relu : public test_case { // GGML_OP_SSM_CONV struct test_ssm_conv : public test_case { const ggml_type type; + const int64_t d_conv; + const int64_t d_inner; + const int64_t n_seq_tokens; + const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR4(type, 3, 1536, 4); + return VARS_TO_STR5(type, d_conv, d_inner, n_seq_tokens, n_seqs); } - test_ssm_conv(ggml_type type = GGML_TYPE_F32) - : type(type) {} + test_ssm_conv(ggml_type type = GGML_TYPE_F32, + int64_t d_conv = 4, + int64_t d_inner = 1536, + int64_t n_seq_tokens = 7, + int64_t n_seqs = 2) + : type(type), d_conv(d_conv), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor_3d(ctx, type, 3, 1536, 1); - ggml_tensor * x = ggml_new_tensor_2d(ctx, type, 1536, 1); - ggml_tensor * c = ggml_new_tensor_2d(ctx, type, 4, 1536); + ggml_tensor * s = ggml_new_tensor_3d(ctx, type, d_conv - 1, d_inner, n_seqs); + ggml_tensor * x = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * c = ggml_new_tensor_2d(ctx, type, d_conv, d_inner); ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c); return out; } @@ -1665,21 +1673,29 @@ struct test_ssm_conv : public test_case { // GGML_OP_SSM_SCAN struct test_ssm_scan : public test_case { const ggml_type type; + const int64_t d_state; + const int64_t d_inner; + const int64_t n_seq_tokens; + const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR4(type, 16, 1536, 2); + return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs); } - test_ssm_scan(ggml_type type = GGML_TYPE_F32) - : type(type) {} + test_ssm_scan(ggml_type type = GGML_TYPE_F32, + int64_t d_state = 16, + int64_t d_inner = 1536, + int64_t n_seq_tokens = 7, + int64_t n_seqs = 2) + : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor_3d(ctx, type, 16, 1536, 1); - ggml_tensor * x = ggml_new_tensor_2d(ctx, type, 1536, 2); - ggml_tensor * dt = ggml_new_tensor_2d(ctx, type, 1536, 2); - ggml_tensor * A = ggml_new_tensor_2d(ctx, type, 16, 1536); - ggml_tensor * B = ggml_new_tensor_2d(ctx, type, 16, 2); - ggml_tensor * C = ggml_new_tensor_2d(ctx, type, 16, 2); + ggml_tensor * s = ggml_new_tensor_3d(ctx, type, d_state, d_inner, n_seqs); + ggml_tensor * x = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, d_state, d_inner); + ggml_tensor * B = ggml_new_tensor_3d(ctx, type, d_state, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_3d(ctx, type, d_state, n_seq_tokens, n_seqs); ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); return out; }