add static ggml_graph_compute_sugar()
This commit is contained in:
parent
db81f33ef2
commit
2b502c32ca
1 changed files with 19 additions and 45 deletions
64
ggml.c
64
ggml.c
|
@ -16424,7 +16424,6 @@ void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgrap
|
||||||
GGML_ASSERT(plan->n_tasks[i] > 0);
|
GGML_ASSERT(plan->n_tasks[i] > 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const int n_threads = plan->n_threads;
|
const int n_threads = plan->n_threads;
|
||||||
|
@ -16491,6 +16490,20 @@ void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgrap
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_graph_compute_sugar(struct ggml_cgraph * cgraph, int n_threads) {
|
||||||
|
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(cgraph, n_threads);
|
||||||
|
if (plan.work_size > 0) {
|
||||||
|
plan.work_data = malloc(plan.work_size);
|
||||||
|
GGML_ASSERT(plan.work_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_graph_compute(&plan, cgraph);
|
||||||
|
|
||||||
|
if (plan.work_data) {
|
||||||
|
free(plan.work_data);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
|
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
|
||||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||||
struct ggml_tensor * grad = cgraph->grads[i];
|
struct ggml_tensor * grad = cgraph->grads[i];
|
||||||
|
@ -17327,17 +17340,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
||||||
ggml_graph_reset (gf);
|
ggml_graph_reset (gf);
|
||||||
ggml_set_f32 (f->grad, 1.0f);
|
ggml_set_f32 (f->grad, 1.0f);
|
||||||
|
|
||||||
{
|
ggml_graph_compute_sugar(gb, params.n_threads);
|
||||||
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
|
|
||||||
if (plan.work_size > 0) {
|
|
||||||
plan.work_data = malloc(plan.work_size);
|
|
||||||
GGML_ASSERT(plan.work_data);
|
|
||||||
}
|
|
||||||
ggml_graph_compute(&plan, gb);
|
|
||||||
if (plan.work_data) {
|
|
||||||
free(plan.work_data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
|
opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
|
||||||
opt->adam.fx_best = opt->adam.fx_prev;
|
opt->adam.fx_best = opt->adam.fx_prev;
|
||||||
|
@ -17418,17 +17421,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
||||||
ggml_graph_reset (gf);
|
ggml_graph_reset (gf);
|
||||||
ggml_set_f32 (f->grad, 1.0f);
|
ggml_set_f32 (f->grad, 1.0f);
|
||||||
|
|
||||||
{
|
ggml_graph_compute_sugar(gb, params.n_threads);
|
||||||
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
|
|
||||||
if (plan.work_size > 0) {
|
|
||||||
plan.work_data = malloc(plan.work_size);
|
|
||||||
GGML_ASSERT(plan.work_data);
|
|
||||||
}
|
|
||||||
ggml_graph_compute(&plan, gb);
|
|
||||||
if (plan.work_data) {
|
|
||||||
free(plan.work_data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const float fx = ggml_get_f32_1d(f, 0);
|
const float fx = ggml_get_f32_1d(f, 0);
|
||||||
|
|
||||||
|
@ -17550,17 +17543,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
||||||
ggml_graph_reset (gf);
|
ggml_graph_reset (gf);
|
||||||
ggml_set_f32 (f->grad, 1.0f);
|
ggml_set_f32 (f->grad, 1.0f);
|
||||||
|
|
||||||
{
|
ggml_graph_compute_sugar(gb, params->n_threads);
|
||||||
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params->n_threads);
|
|
||||||
if (plan.work_size > 0) {
|
|
||||||
plan.work_data = malloc(plan.work_size);
|
|
||||||
GGML_ASSERT(plan.work_data);
|
|
||||||
}
|
|
||||||
ggml_graph_compute(&plan, gb);
|
|
||||||
if (plan.work_data) {
|
|
||||||
free(plan.work_data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_opt_get_grad(np, ps, g);
|
ggml_opt_get_grad(np, ps, g);
|
||||||
|
|
||||||
|
@ -17679,17 +17662,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
||||||
|
|
||||||
ggml_graph_reset (gf);
|
ggml_graph_reset (gf);
|
||||||
ggml_set_f32 (f->grad, 1.0f);
|
ggml_set_f32 (f->grad, 1.0f);
|
||||||
{
|
|
||||||
struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads);
|
ggml_graph_compute_sugar(gb, params.n_threads);
|
||||||
if (plan.work_size > 0) {
|
|
||||||
plan.work_data = malloc(plan.work_size);
|
|
||||||
GGML_ASSERT(plan.work_data);
|
|
||||||
}
|
|
||||||
ggml_graph_compute(&plan, gb);
|
|
||||||
if (plan.work_data) {
|
|
||||||
free(plan.work_data);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_opt_get_grad(np, ps, g);
|
ggml_opt_get_grad(np, ps, g);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue