wipwipwiwpip
This commit is contained in:
parent
fc59407efe
commit
ddc59e8e0a
4 changed files with 132 additions and 1 deletions
|
@ -1561,6 +1561,56 @@ struct test_flash_attn_ext : public test_case {
|
|||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SSM_CONV
|
||||
struct test_ssm_conv : public test_case {
|
||||
const ggml_type type_s;
|
||||
const ggml_type type_x;
|
||||
const ggml_type type_c;
|
||||
const ggml_type type_sq;
|
||||
const int64_t d_inner;
|
||||
const int64_t d_conv;
|
||||
const int64_t n_tokens;
|
||||
const int64_t n_rs;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR8(type_s, type_x, type_c, type_sq, d_inner, d_conv, n_tokens, n_rs);
|
||||
}
|
||||
|
||||
test_ssm_conv(ggml_type type_s = GGML_TYPE_F32,
|
||||
ggml_type type_x = GGML_TYPE_F32,
|
||||
ggml_type type_c = GGML_TYPE_F32,
|
||||
ggml_type type_sq = GGML_TYPE_I32,
|
||||
int64_t d_inner = 10,
|
||||
int64_t d_conv = 10,
|
||||
int64_t n_tokens = 10,
|
||||
int64_t n_rs = 10)
|
||||
: type_s(type_s), type_x(type_x), type_c(type_c), type_sq(type_sq), d_inner(d_inner), d_conv(d_conv), n_tokens(n_tokens), n_rs(n_rs) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * s = ggml_new_tensor_3d (ctx, type_s, d_conv-1, d_inner, n_rs);
|
||||
ggml_tensor * x = ggml_new_tensor_2d (ctx, type_x, d_inner, n_tokens);
|
||||
ggml_tensor * c = ggml_new_tensor_2d (ctx, type_c, d_conv, d_inner);
|
||||
ggml_tensor * sq = ggml_new_tensor_1d(ctx, type_sq, n_tokens);
|
||||
ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c, sq);
|
||||
return out;
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
// pos
|
||||
std::vector<int> data(t->ne[0]);
|
||||
for (int i = 0; i < t->ne[0]; i++) {
|
||||
data[i] = rand() % n_rs;
|
||||
}
|
||||
ggml_backend_tensor_set(t, data.data(), 0, t->ne[0] * sizeof(int));
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
enum llm_norm_type {
|
||||
LLM_NORM,
|
||||
LLM_NORM_RMS,
|
||||
|
@ -2246,6 +2296,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_ssm_conv());
|
||||
|
||||
// these tests are disabled to save execution time, but they can be handy for debugging
|
||||
#if 0
|
||||
test_cases.emplace_back(new test_llama(1));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue