ggml : add SSM Metal kernels (#8546)
* ggml : add ggml_ssm_conv metal impl * ggml : add ssm_scan metal impl ggml-ci
This commit is contained in:
parent
879275ac98
commit
fc18425b6a
4 changed files with 303 additions and 2 deletions
|
@ -949,6 +949,58 @@ struct test_rms_norm : public test_case {
|
|||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SSM_CONV
|
||||
struct test_ssm_conv : public test_case {
|
||||
const ggml_type type;
|
||||
const std::array<int64_t, 4> ne_a;
|
||||
const std::array<int64_t, 4> ne_b;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR3(type, ne_a, ne_b);
|
||||
}
|
||||
|
||||
test_ssm_conv(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne_a = {10, 10, 10, 1},
|
||||
std::array<int64_t, 4> ne_b = {3, 3, 1, 1})
|
||||
: type(type), ne_a(ne_a), ne_b(ne_b) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
|
||||
ggml_tensor * out = ggml_ssm_conv(ctx, a, b);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SSM_SCAN
|
||||
struct test_ssm_scan : public test_case {
|
||||
const ggml_type type;
|
||||
|
||||
const int64_t d_state;
|
||||
const int64_t d_inner;
|
||||
const int64_t n_seq_tokens;
|
||||
const int64_t n_seqs;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
|
||||
}
|
||||
|
||||
test_ssm_scan(ggml_type type = GGML_TYPE_F32,
|
||||
int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
|
||||
: type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, n_seqs, 1 }.data());
|
||||
ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
|
||||
ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
|
||||
ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, 1 , 1 }.data());
|
||||
ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
|
||||
ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
|
||||
ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_MUL_MAT
|
||||
struct test_mul_mat : public test_case {
|
||||
const ggml_type type_a;
|
||||
|
@ -2240,6 +2292,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
|
||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
|
||||
|
||||
#if 1
|
||||
for (ggml_type type_a : base_types) {
|
||||
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue