ggml : avoid multiply by D in GGML_OP_SSM_SCAN

This makes the weight buffer type ("buft") detection in src/llama.cpp simpler.
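
A minimal sketch of what this means for the caller, using the shapes from the test below (illustrative only, not the verbatim src/llama.cpp graph code; the real ggml_ssm_scan output also carries the updated states, which the view here assumes are laid out after y):

    // The D skip connection now lives outside the op.
    // x: [head_dim, n_head, n_seq_tokens, n_seqs]
    // D: assumed stored as [1, n_head] (see the conversion change below),
    //    so ggml_mul can broadcast it over head_dim without a reshape.
    ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids); // no D argument
    // y part of the scan output, assuming it comes first and is contiguous:
    ggml_tensor * y = ggml_view_4d(ctx, out,
            head_dim, n_head, n_seq_tokens, n_seqs,
            head_dim * ggml_element_size(out),
            head_dim * n_head * ggml_element_size(out),
            head_dim * n_head * n_seq_tokens * ggml_element_size(out), 0);
    // y = y + x*D, previously fused into GGML_OP_SSM_SCAN:
    y = ggml_add(ctx, y, ggml_mul(ctx, x, D));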

* convert : transpose Mamba-2 A, D and reshape SSM_NORM

This breaks existing conversions of Mamba-2 models,
in order to avoid some runtime reshapes.

Not sure if it's a good idea,
but it makes the graph slightly cleaner.
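
The transposes happen in the Python converter, but their effect is easiest to see from the C side as the layouts the graph now expects straight from the GGUF. A hedged sketch (only A and D are directly implied by the diff below; the SSM_NORM target shape is an assumption based on the grouped norm):

    // Assumed post-conversion layouts, illustrative only:
    ggml_tensor * A    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_head);  // was [n_head]
    ggml_tensor * D    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_head);  // was [n_head]
    ggml_tensor * norm = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,
            d_inner / n_group, n_group);            // was [d_inner] (assumption)
    // With these layouts the graph drops per-layer nodes like
    // ggml_reshape_2d(ctx, A, 1, n_head), which is presumably what
    // "slightly cleaner" refers to.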

* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks
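
A hedged sketch of the kind of check this refers to (helper name and probe sizes are made up, not the verbatim src/llama.cpp code): instead of special-casing weights by tensor name, build a representative op over the weight and ask the device whether it supports it.

    #include "ggml.h"
    #include "ggml-backend.h"

    // Probe whether a device can run GGML_OP_SSM_CONV over this conv weight.
    static bool ssm_conv_weight_supported(ggml_context * ctx, ggml_backend_dev_t dev,
                                          ggml_tensor * conv_w /* [d_conv, d_inner] */) {
        const int64_t n_seq_tokens = 512; // arbitrary probe size
        const int64_t n_seqs       = 1;
        ggml_tensor * sx = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,
                conv_w->ne[0] - 1 + n_seq_tokens, conv_w->ne[1], n_seqs);
        ggml_tensor * op = ggml_ssm_conv(ctx, sx, conv_w);
        return ggml_backend_dev_supports_op(dev, op);
    }

An analogous probe would build a ggml_ssm_scan node for the SSM_SCAN weights.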
author Francis Couture-Harpin 2024-11-04 11:36:37 -05:00
parent 7d16e1bc8c
commit 3bc7103d2e
7 changed files with 98 additions and 95 deletions

tests/test-backend-ops.cpp
@@ -1589,35 +1589,34 @@ struct test_ssm_scan : public test_case {
     const ggml_type type;
     const int64_t d_state;
-    const int64_t d_inner;
+    const int64_t head_dim;
     const int64_t n_head;
     const int64_t n_group;
     const int64_t n_seq_tokens;
     const int64_t n_seqs;

     std::string vars() override {
-        return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs);
+        return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs);
     }

     test_ssm_scan(ggml_type type = GGML_TYPE_F32,
                   int64_t d_state = 32,
-                  int64_t d_inner = 1, // non-zero for Mamba-2
+                  int64_t head_dim = 1, // non-zero for Mamba-2
                   int64_t n_head = 32,
                   int64_t n_group = 1,
                   int64_t n_seq_tokens = 32,
                   int64_t n_seqs = 32)
-        : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+        : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * s   = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs);
-        ggml_tensor * x   = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs);
-        ggml_tensor * dt  = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
-        ggml_tensor * A   = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head);
-        ggml_tensor * B   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
-        ggml_tensor * C   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
-        ggml_tensor * D   = ggml_new_tensor_1d(ctx, type, n_head);
-        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
-        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids);
+        ggml_tensor * s   = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs);
+        ggml_tensor * x   = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs);
+        ggml_tensor * dt  = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
+        ggml_tensor * A   = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head);
+        ggml_tensor * B   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+        ggml_tensor * C   = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
+        ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
+        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids);
         return out;
     }
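
For reference, registrations of the updated test would look roughly like this (the parameter values here are hypothetical, chosen only to exercise both paths):

    // hypothetical values, not the ones in the test file:
    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32,  16,  1, 1024, 1, 32, 4)); // Mamba-1 (head_dim == 1)
    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64,   16, 2, 32, 4)); // Mamba-2 (head_dim > 1)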