wipwipwiwpip

This commit is contained in:
Georgi Gerganov 2024-05-27 12:04:09 +03:00
parent fc59407efe
commit ddc59e8e0a
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 132 additions and 1 deletions

View file

@ -1561,6 +1561,56 @@ struct test_flash_attn_ext : public test_case {
}
};
// GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case {
const ggml_type type_s;
const ggml_type type_x;
const ggml_type type_c;
const ggml_type type_sq;
const int64_t d_inner;
const int64_t d_conv;
const int64_t n_tokens;
const int64_t n_rs;
std::string vars() override {
return VARS_TO_STR8(type_s, type_x, type_c, type_sq, d_inner, d_conv, n_tokens, n_rs);
}
test_ssm_conv(ggml_type type_s = GGML_TYPE_F32,
ggml_type type_x = GGML_TYPE_F32,
ggml_type type_c = GGML_TYPE_F32,
ggml_type type_sq = GGML_TYPE_I32,
int64_t d_inner = 10,
int64_t d_conv = 10,
int64_t n_tokens = 10,
int64_t n_rs = 10)
: type_s(type_s), type_x(type_x), type_c(type_c), type_sq(type_sq), d_inner(d_inner), d_conv(d_conv), n_tokens(n_tokens), n_rs(n_rs) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * s = ggml_new_tensor_3d (ctx, type_s, d_conv-1, d_inner, n_rs);
ggml_tensor * x = ggml_new_tensor_2d (ctx, type_x, d_inner, n_tokens);
ggml_tensor * c = ggml_new_tensor_2d (ctx, type_c, d_conv, d_inner);
ggml_tensor * sq = ggml_new_tensor_1d(ctx, type_sq, n_tokens);
ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c, sq);
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
// pos
std::vector<int> data(t->ne[0]);
for (int i = 0; i < t->ne[0]; i++) {
data[i] = rand() % n_rs;
}
ggml_backend_tensor_set(t, data.data(), 0, t->ne[0] * sizeof(int));
} else {
init_tensor_uniform(t);
}
}
}
};
enum llm_norm_type {
LLM_NORM,
LLM_NORM_RMS,
@ -2246,6 +2296,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
}
test_cases.emplace_back(new test_ssm_conv());
// these tests are disabled to save execution time, but they can be handy for debugging
#if 0
test_cases.emplace_back(new test_llama(1));