Add GGML_OP_SSM_CONF, GGML_OP_SSM_SCAN to supported ops for CUDA backend + test case for each op
This commit is contained in:
parent
f809568fa1
commit
cc365b045b
2 changed files with 74 additions and 0 deletions
|
@ -2885,6 +2885,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||||
case GGML_OP_ARANGE:
|
case GGML_OP_ARANGE:
|
||||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
|
case GGML_OP_SSM_CONV:
|
||||||
|
case GGML_OP_SSM_SCAN:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
|
|
@ -1642,6 +1642,76 @@ struct test_leaky_relu : public test_case {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// GGML_OP_SSM_CONV
|
||||||
|
struct test_ssm_conv : public test_case {
|
||||||
|
const ggml_type type;
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR4(type, 3, 1536, 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
test_ssm_conv(ggml_type type = GGML_TYPE_F32)
|
||||||
|
: type(type) {}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
ggml_tensor * s = ggml_new_tensor_3d(ctx, type, 3, 1536, 1);
|
||||||
|
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, 1536, 1);
|
||||||
|
ggml_tensor * c = ggml_new_tensor_2d(ctx, type, 4, 1536);
|
||||||
|
ggml_tensor * sq = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 1, 1);
|
||||||
|
ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c, sq);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void initialize_tensors(ggml_context * ctx) override {
|
||||||
|
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||||
|
if (t->type == GGML_TYPE_I32) {
|
||||||
|
std::vector<int> data(1);
|
||||||
|
data[0] = 0;
|
||||||
|
ggml_backend_tensor_set(t, data.data(), 0, 1 * sizeof(int));
|
||||||
|
} else {
|
||||||
|
init_tensor_uniform(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// GGML_OP_SSM_SCAN
|
||||||
|
struct test_ssm_scan : public test_case {
|
||||||
|
const ggml_type type;
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR4(type, 16, 1536, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
test_ssm_scan(ggml_type type = GGML_TYPE_F32)
|
||||||
|
: type(type) {}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
ggml_tensor * s = ggml_new_tensor_3d(ctx, type, 16, 1536, 1);
|
||||||
|
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, 1536, 2);
|
||||||
|
ggml_tensor * dt = ggml_new_tensor_2d(ctx, type, 1536, 2);
|
||||||
|
ggml_tensor * A = ggml_new_tensor_2d(ctx, type, 16, 1536);
|
||||||
|
ggml_tensor * B = ggml_new_tensor_2d(ctx, type, 16, 2);
|
||||||
|
ggml_tensor * C = ggml_new_tensor_2d(ctx, type, 16, 2);
|
||||||
|
ggml_tensor * sq = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 1, 2);
|
||||||
|
ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, sq);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void initialize_tensors(ggml_context * ctx) override {
|
||||||
|
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||||
|
if (t->type == GGML_TYPE_I32) {
|
||||||
|
std::vector<int> data(2);
|
||||||
|
data[0] = 0;
|
||||||
|
data[1] = 0;
|
||||||
|
ggml_backend_tensor_set(t, data.data(), 0, 2 * sizeof(int));
|
||||||
|
} else {
|
||||||
|
init_tensor_uniform(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// GGML_OP_FLASH_ATTN_EXT
|
// GGML_OP_FLASH_ATTN_EXT
|
||||||
struct test_flash_attn_ext : public test_case {
|
struct test_flash_attn_ext : public test_case {
|
||||||
const int64_t hs; // head size
|
const int64_t hs; // head size
|
||||||
|
@ -2433,6 +2503,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||||
test_cases.emplace_back(new test_arange());
|
test_cases.emplace_back(new test_arange());
|
||||||
test_cases.emplace_back(new test_timestep_embedding());
|
test_cases.emplace_back(new test_timestep_embedding());
|
||||||
test_cases.emplace_back(new test_leaky_relu());
|
test_cases.emplace_back(new test_leaky_relu());
|
||||||
|
test_cases.emplace_back(new test_ssm_conv());
|
||||||
|
test_cases.emplace_back(new test_ssm_scan());
|
||||||
|
|
||||||
for (int hs : { 64, 80, 128, 256, }) {
|
for (int hs : { 64, 80, 128, 256, }) {
|
||||||
for (bool mask : { true, false } ) {
|
for (bool mask : { true, false } ) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue