llama : avoid redundant state copy for Mamba 1 and 2

This commit is contained in:
Francis Couture-Harpin 2024-09-30 15:52:42 -04:00
parent 0e601cafe9
commit 273e7a495a
4 changed files with 142 additions and 119 deletions

View file

@ -1530,27 +1530,58 @@ struct test_ssm_scan : public test_case {
const int64_t d_state;
const int64_t d_inner;
const int64_t n_head;
const int64_t n_group;
const int64_t n_seq_tokens;
const int64_t n_seqs;
std::string vars() override {
return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs);
}
test_ssm_scan(ggml_type type = GGML_TYPE_F32,
int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
: type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
int64_t d_state = 32,
int64_t d_inner = 1, // non-zero for Mamba-2
int64_t n_head = 32,
int64_t n_group = 1,
int64_t n_seq_tokens = 32,
int64_t n_seqs = 32)
: type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, n_seqs, 1 }.data());
ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, 1 , 1 }.data());
ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs);
ggml_tensor * x = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs);
ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs);
ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head);
ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs);
ggml_tensor * D = ggml_new_tensor_1d(ctx, type, n_head);
ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);
ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids);
return out;
}
// similar to test_mul_mat_id
void initialize_tensors(ggml_context * ctx) override {
std::random_device rd;
std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) { continue; }
// ids
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<int32_t> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = i;
}
std::shuffle(data.begin(), data.end(), rng);
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
}
} else {
init_tensor_uniform(t);
}
}
}
};
// GGML_OP_MUL_MAT
@ -3255,7 +3286,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2
#if 1
for (ggml_type type_a : base_types) {