Add patch to test cases provided by @compilade; test for ssm_conv fails

This commit is contained in:
Jan Ploski 2024-06-02 19:20:17 +02:00
parent 25f9e65d3a
commit 64fbd320ef

View file

@ -1645,18 +1645,26 @@ struct test_leaky_relu : public test_case {
// GGML_OP_SSM_CONV // GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case { struct test_ssm_conv : public test_case {
const ggml_type type; const ggml_type type;
const int64_t d_conv;
const int64_t d_inner;
const int64_t n_seq_tokens;
const int64_t n_seqs;
std::string vars() override { std::string vars() override {
return VARS_TO_STR4(type, 3, 1536, 4); return VARS_TO_STR5(type, d_conv, d_inner, n_seq_tokens, n_seqs);
} }
test_ssm_conv(ggml_type type = GGML_TYPE_F32) test_ssm_conv(ggml_type type = GGML_TYPE_F32,
: type(type) {} int64_t d_conv = 4,
int64_t d_inner = 1536,
int64_t n_seq_tokens = 7,
int64_t n_seqs = 2)
: type(type), d_conv(d_conv), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * s = ggml_new_tensor_3d(ctx, type, 3, 1536, 1); ggml_tensor * s = ggml_new_tensor_3d(ctx, type, d_conv - 1, d_inner, n_seqs);
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, 1536, 1); ggml_tensor * x = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
ggml_tensor * c = ggml_new_tensor_2d(ctx, type, 4, 1536); ggml_tensor * c = ggml_new_tensor_2d(ctx, type, d_conv, d_inner);
ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c); ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c);
return out; return out;
} }
@ -1665,21 +1673,29 @@ struct test_ssm_conv : public test_case {
// GGML_OP_SSM_SCAN // GGML_OP_SSM_SCAN
struct test_ssm_scan : public test_case { struct test_ssm_scan : public test_case {
const ggml_type type; const ggml_type type;
const int64_t d_state;
const int64_t d_inner;
const int64_t n_seq_tokens;
const int64_t n_seqs;
std::string vars() override { std::string vars() override {
return VARS_TO_STR4(type, 16, 1536, 2); return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
} }
test_ssm_scan(ggml_type type = GGML_TYPE_F32) test_ssm_scan(ggml_type type = GGML_TYPE_F32,
: type(type) {} int64_t d_state = 16,
int64_t d_inner = 1536,
int64_t n_seq_tokens = 7,
int64_t n_seqs = 2)
: type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * s = ggml_new_tensor_3d(ctx, type, 16, 1536, 1); ggml_tensor * s = ggml_new_tensor_3d(ctx, type, d_state, d_inner, n_seqs);
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, 1536, 2); ggml_tensor * x = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
ggml_tensor * dt = ggml_new_tensor_2d(ctx, type, 1536, 2); ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
ggml_tensor * A = ggml_new_tensor_2d(ctx, type, 16, 1536); ggml_tensor * A = ggml_new_tensor_2d(ctx, type, d_state, d_inner);
ggml_tensor * B = ggml_new_tensor_2d(ctx, type, 16, 2); ggml_tensor * B = ggml_new_tensor_3d(ctx, type, d_state, n_seq_tokens, n_seqs);
ggml_tensor * C = ggml_new_tensor_2d(ctx, type, 16, 2); ggml_tensor * C = ggml_new_tensor_3d(ctx, type, d_state, n_seq_tokens, n_seqs);
ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
return out; return out;
} }