ggml : add ggml_graph_compute_with_ctx()

- backwards compatible API
- deduplicates a lot of copy-paste
This commit is contained in:
Georgi Gerganov 2023-07-06 20:43:43 +03:00
parent 8e1f0b6865
commit 2392f7a9cd
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 29 additions and 125 deletions

32
ggml.c
View file

@@ -16493,21 +16493,17 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
} }
} }
// TODO: avoid allocating memory frequently. // same as ggml_graph_compute() but the work data is allocated as a part of the context
// TODO: make part of public API - use different name and put warning that it makes allocations // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
static void ggml_graph_compute_helper(struct ggml_cgraph * cgraph, int n_threads) { void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads); struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
if (cplan.work_size > 0) { struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
cplan.work_data = malloc(cplan.work_size); GGML_ASSERT(buf);
GGML_ASSERT(cplan.work_data);
} cplan.work_data = buf->data;
ggml_graph_compute(cgraph, &cplan); ggml_graph_compute(cgraph, &cplan);
if (cplan.work_data) {
free(cplan.work_data);
}
} }
void ggml_graph_reset(struct ggml_cgraph * cgraph) { void ggml_graph_reset(struct ggml_cgraph * cgraph) {
@@ -17292,6 +17288,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
// //
static enum ggml_opt_result ggml_opt_adam( static enum ggml_opt_result ggml_opt_adam(
struct ggml_context * ctx,
struct ggml_opt_context * opt, struct ggml_opt_context * opt,
struct ggml_opt_params params, struct ggml_opt_params params,
struct ggml_tensor * f, struct ggml_tensor * f,
@ -17346,7 +17343,7 @@ static enum ggml_opt_result ggml_opt_adam(
ggml_graph_reset (gf); ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f); ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute_helper(gb, params.n_threads); ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
opt->adam.fx_prev = ggml_get_f32_1d(f, 0); opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
opt->adam.fx_best = opt->adam.fx_prev; opt->adam.fx_best = opt->adam.fx_prev;
@ -17427,7 +17424,7 @@ static enum ggml_opt_result ggml_opt_adam(
ggml_graph_reset (gf); ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f); ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute_helper(gb, params.n_threads); ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
const float fx = ggml_get_f32_1d(f, 0); const float fx = ggml_get_f32_1d(f, 0);
@ -17498,6 +17495,7 @@ struct ggml_lbfgs_iteration_data {
}; };
static enum ggml_opt_result linesearch_backtracking( static enum ggml_opt_result linesearch_backtracking(
struct ggml_context * ctx,
const struct ggml_opt_params * params, const struct ggml_opt_params * params,
int nx, int nx,
float * x, float * x,
@ -17549,7 +17547,7 @@ static enum ggml_opt_result linesearch_backtracking(
ggml_graph_reset (gf); ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f); ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute_helper(gb, params->n_threads); ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);
ggml_opt_get_grad(np, ps, g); ggml_opt_get_grad(np, ps, g);
@ -17669,7 +17667,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
ggml_graph_reset (gf); ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f); ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute_helper(gb, params.n_threads); ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
ggml_opt_get_grad(np, ps, g); ggml_opt_get_grad(np, ps, g);
@ -17728,7 +17726,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
ggml_vec_cpy_f32(nx, xp, x); ggml_vec_cpy_f32(nx, xp, x);
ggml_vec_cpy_f32(nx, gp, g); ggml_vec_cpy_f32(nx, gp, g);
ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps); ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
if (ls < 0) { if (ls < 0) {
// linesearch failed - go back to the previous point and return // linesearch failed - go back to the previous point and return
@ -18030,7 +18028,7 @@ enum ggml_opt_result ggml_opt_resume_g(
switch (opt->params.type) { switch (opt->params.type) {
case GGML_OPT_ADAM: case GGML_OPT_ADAM:
{ {
result = ggml_opt_adam(opt, opt->params, f, gf, gb); result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
} break; } break;
case GGML_OPT_LBFGS: case GGML_OPT_LBFGS:
{ {

4
ggml.h
View file

@@ -1319,6 +1319,10 @@ extern "C" {
GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
// same as ggml_graph_compute() but the work data is allocated as a part of the context
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name); GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname); GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);

View file

@ -195,32 +195,6 @@ void print_elements(const char* label, const struct ggml_tensor * t) {
} }
// grow-only scratch buffer reused across repeated graph computations
struct work_buffer {
    size_t    size; // current capacity in bytes
    uint8_t * data; // heap storage owned by this struct (NULL when size == 0)
};

// ensure buf holds at least `size` bytes and return its storage;
// returns NULL when size == 0 (no work data needed); never shrinks
static uint8_t * work_buffer_resize(struct work_buffer * buf, size_t size) {
    if (size == 0) {
        return NULL;
    }

    GGML_ASSERT(buf);

    if (buf->size < size) {
        // realloc(NULL, n) behaves like malloc(n), so this covers the first
        // allocation too; keep the old pointer until the new one is known
        // to be valid, so a failed realloc does not leak the old storage
        uint8_t * tmp = realloc(buf->data, size);
        GGML_ASSERT(tmp);

        buf->data = tmp;
        buf->size = size;
    }

    return buf->data;
}
bool check_gradient( bool check_gradient(
const char * op_name, const char * op_name,
struct ggml_context * ctx0, struct ggml_context * ctx0,
@ -247,28 +221,12 @@ bool check_gradient(
struct ggml_cgraph gf = ggml_build_forward (f); struct ggml_cgraph gf = ggml_build_forward (f);
struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false); struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
struct work_buffer buf = { /*.size = */ 0, /*.data =*/ NULL }; ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
{
struct ggml_cplan pf = ggml_graph_plan(&gf, n_threads);
if (pf.work_size > 0) {
pf.work_data = malloc(pf.work_size);
GGML_ASSERT(pf.work_data);
}
ggml_graph_compute(&gf, &pf);
if (pf.work_data) {
free(pf.work_data);
}
}
ggml_graph_reset (&gf); ggml_graph_reset (&gf);
ggml_set_f32 (f->grad, 1.0f); ggml_set_f32 (f->grad, 1.0f);
{ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
struct ggml_cplan pf = ggml_graph_plan(&gb, n_threads);
pf.work_data = work_buffer_resize(&buf, pf.work_size);
ggml_graph_compute(&gf, &pf);
}
// ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot"); // ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
// ggml_graph_dump_dot(&gb, &gf, "test-grad0-backward.dot"); // ggml_graph_dump_dot(&gb, &gf, "test-grad0-backward.dot");
@ -282,24 +240,15 @@ bool check_gradient(
const float xp = x0 + eps; const float xp = x0 + eps;
set_element(x[i], k, xp); set_element(x[i], k, xp);
{ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
struct ggml_cplan pf = ggml_graph_plan(&gf, n_threads);
pf.work_data = work_buffer_resize(&buf, pf.work_size);
ggml_graph_compute(&gf, &pf);
}
const float f0 = ggml_get_f32_1d(f, 0); const float f0 = ggml_get_f32_1d(f, 0);
set_element(x[i], k, xm); set_element(x[i], k, xm);
{ ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
struct ggml_cplan pf = ggml_graph_plan(&gf, n_threads);
pf.work_data = work_buffer_resize(&buf, pf.work_size);
ggml_graph_compute(&gf, &pf);
}
const float f1 = ggml_get_f32_1d(f, 0); const float f1 = ggml_get_f32_1d(f, 0);
const float g0 = (f0 - f1)/(2.0f*eps); const float g0 = (f0 - f1)/(2.0f*eps);
set_element(x[i], k, x0); set_element(x[i], k, x0);
@ -308,11 +257,7 @@ bool check_gradient(
ggml_graph_reset (&gf); ggml_graph_reset (&gf);
ggml_set_f32 (f->grad, 1.0f); ggml_set_f32 (f->grad, 1.0f);
{ ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
struct ggml_cplan pf = ggml_graph_plan(&gb, n_threads);
pf.work_data = work_buffer_resize(&buf, pf.work_size);
ggml_graph_compute(&gf, &pf);
}
const float g1 = get_element(x[i]->grad, k); const float g1 = get_element(x[i]->grad, k);
@ -328,10 +273,6 @@ bool check_gradient(
} }
} }
if (buf.data) {
free(buf.data);
}
return true; return true;
} }

View file

@ -115,31 +115,6 @@ void set_element(struct ggml_tensor * t, int idx, float value) {
((float *)t->data)[idx] = value; ((float *)t->data)[idx] = value;
} }
// grow-only scratch buffer reused across repeated graph computations
struct work_buffer {
    size_t    size; // current capacity in bytes
    uint8_t * data; // heap storage owned by this struct (NULL when size == 0)
};

// ensure buf holds at least `size` bytes and return its storage;
// returns NULL when size == 0 (no work data needed); never shrinks
static uint8_t * work_buffer_resize(struct work_buffer * buf, size_t size) {
    if (size == 0) {
        return NULL;
    }

    GGML_ASSERT(buf); // guard against a NULL buffer, matching the other copy of this helper

    if (buf->size < size) {
        // realloc(NULL, n) behaves like malloc(n), so this covers the first
        // allocation too; keep the old pointer until the new one is known
        // to be valid, so a failed realloc does not leak the old storage
        uint8_t * tmp = realloc(buf->data, size);
        GGML_ASSERT(tmp);

        buf->data = tmp;
        buf->size = size;
    }

    return buf->data;
}
int main(void) { int main(void) {
struct ggml_init_params params = { struct ggml_init_params params = {
.mem_size = 1024*1024*1024, .mem_size = 1024*1024*1024,
@ -163,16 +138,10 @@ int main(void) {
struct ggml_tensor * d = ggml_sub(ctx, c, ab); struct ggml_tensor * d = ggml_sub(ctx, c, ab);
struct ggml_tensor * e = ggml_sum(ctx, ggml_sqr(ctx, d)); struct ggml_tensor * e = ggml_sum(ctx, ggml_sqr(ctx, d));
struct ggml_cgraph ge = ggml_build_forward(e); struct ggml_cgraph ge = ggml_build_forward(e);
ggml_graph_reset (&ge); ggml_graph_reset(&ge);
struct work_buffer buf = { /*.size = */ 0, /*.data =*/ NULL }; ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
{
struct ggml_cplan pe = ggml_graph_plan(&ge, /*n_threads*/ 1);
pe.work_data = work_buffer_resize(&buf, pe.work_size);
ggml_graph_compute(&ge, &pe);
}
const float fe = ggml_get_f32_1d(e, 0); const float fe = ggml_get_f32_1d(e, 0);
printf("%s: e = %.4f\n", __func__, fe); printf("%s: e = %.4f\n", __func__, fe);
@ -181,17 +150,9 @@ int main(void) {
ggml_opt(ctx, opt_params, e); ggml_opt(ctx, opt_params, e);
ggml_graph_reset (&ge); ggml_graph_reset(&ge);
{ ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
struct ggml_cplan pe = ggml_graph_plan(&ge, /*n_threads*/ 1);
pe.work_data = work_buffer_resize(&buf, pe.work_size);
ggml_graph_compute(&ge, &pe);
}
if (buf.data) {
free(buf.data);
}
const float fe_opt = ggml_get_f32_1d(e, 0); const float fe_opt = ggml_get_f32_1d(e, 0);
printf("%s: original e = %.4f\n", __func__, fe); printf("%s: original e = %.4f\n", __func__, fe);