ggml : remove ggml_cplan + rework ggml_cgraph
ggml-ci
This commit is contained in:
parent
ee154457dd
commit
119e0bc9ae
10 changed files with 248 additions and 175 deletions
|
@ -242,12 +242,16 @@ static bool check_gradient(
|
|||
ggml_graph_cpy(gf, gb);
|
||||
ggml_build_backward_expand(ctx0, gf, gb, false);
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
|
||||
ggml_graph_prepare(gf, n_threads, nullptr);
|
||||
ggml_graph_work_init(gf, ctx0);
|
||||
ggml_graph_compute(gf);
|
||||
|
||||
ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
|
||||
ggml_graph_prepare(gb, n_threads, nullptr);
|
||||
ggml_graph_work_init(gb, ctx0);
|
||||
ggml_graph_compute(gb);
|
||||
|
||||
// ggml_graph_dump_dot(gf, NULL, "test-grad0-forward.dot");
|
||||
// ggml_graph_dump_dot(gb, gf, "test-grad0-backward.dot");
|
||||
|
@ -262,13 +266,17 @@ static bool check_gradient(
|
|||
const float xp = x0 + eps;
|
||||
ggml_set_f32_1d(x[i], k, xp);
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
|
||||
ggml_graph_prepare(gf, n_threads, nullptr);
|
||||
ggml_graph_work_init(gf, ctx0);
|
||||
ggml_graph_compute(gf);
|
||||
|
||||
const double f0 = ggml_get_f32_1d(f, 0);
|
||||
|
||||
ggml_set_f32_1d(x[i], k, xm);
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, gf, n_threads);
|
||||
ggml_graph_prepare(gf, n_threads, nullptr);
|
||||
ggml_graph_work_init(gf, ctx0);
|
||||
ggml_graph_compute(gf);
|
||||
|
||||
const double f1 = ggml_get_f32_1d(f, 0);
|
||||
const double g0 = (f0 - f1)/(2.0*(double) eps);
|
||||
|
@ -301,7 +309,9 @@ static bool check_gradient(
|
|||
ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx0, gb, n_threads);
|
||||
ggml_graph_prepare(gb, n_threads, nullptr);
|
||||
ggml_graph_work_init(gb, ctx0);
|
||||
ggml_graph_compute(gb);
|
||||
|
||||
const double g1 = ggml_get_f32_1d(x[i]->grad, k);
|
||||
|
||||
|
|
|
@ -113,7 +113,10 @@ int main(void) {
|
|||
ggml_build_forward_expand(ge, e);
|
||||
ggml_graph_reset(ge);
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1);
|
||||
ggml_graph_prepare(ge, 1, nullptr);
|
||||
ggml_graph_work_init(ge, nullptr);
|
||||
ggml_graph_compute(ge);
|
||||
ggml_graph_work_free(ge);
|
||||
|
||||
const float fe = ggml_get_f32_1d(e, 0);
|
||||
printf("%s: e = %.4f\n", __func__, fe);
|
||||
|
@ -124,7 +127,10 @@ int main(void) {
|
|||
|
||||
ggml_graph_reset(ge);
|
||||
|
||||
ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1);
|
||||
ggml_graph_prepare(ge, 1, nullptr);
|
||||
ggml_graph_work_init(ge, nullptr);
|
||||
ggml_graph_compute(ge);
|
||||
ggml_graph_work_free(ge);
|
||||
|
||||
const float fe_opt = ggml_get_f32_1d(e, 0);
|
||||
printf("%s: original e = %.4f\n", __func__, fe);
|
||||
|
|
|
@ -112,17 +112,6 @@ static struct ggml_tensor * get_random_tensor_f32(
|
|||
return result;
|
||||
}
|
||||
|
||||
static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
|
||||
struct ggml_cplan plan = ggml_graph_plan(graph, n_threads, nullptr);
|
||||
|
||||
if (plan.work_size > 0) {
|
||||
buf.resize(plan.work_size);
|
||||
plan.work_data = buf.data();
|
||||
}
|
||||
|
||||
ggml_graph_compute(graph, &plan);
|
||||
}
|
||||
|
||||
int main(int /*argc*/, const char ** /*argv*/) {
|
||||
struct ggml_init_params params = {
|
||||
/* .mem_size = */ 128*1024*1024,
|
||||
|
@ -130,8 +119,6 @@ int main(int /*argc*/, const char ** /*argv*/) {
|
|||
/* .no_alloc = */ false,
|
||||
};
|
||||
|
||||
std::vector<uint8_t> work_buffer;
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
struct ggml_tensor * x;
|
||||
|
@ -175,7 +162,10 @@ int main(int /*argc*/, const char ** /*argv*/) {
|
|||
ggml_build_forward_expand(gf, r1);
|
||||
ggml_build_forward_expand(gf, r2);
|
||||
|
||||
ggml_graph_compute_helper(work_buffer, gf, 4);
|
||||
ggml_graph_prepare(gf, 4, nullptr);
|
||||
ggml_graph_work_init(gf, nullptr);
|
||||
ggml_graph_compute(gf);
|
||||
ggml_graph_work_free(gf);
|
||||
|
||||
// check that r1 and r2 are the same
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue