ggml : add ggml_graph_compute_with_ctx()

- backwards compatible API
- deduplicates a lot of copy-paste
This commit is contained in:
Georgi Gerganov 2023-07-06 20:43:43 +03:00
parent 8e1f0b6865
commit 2392f7a9cd
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 29 additions and 125 deletions

32
ggml.c
View file

@ -16493,21 +16493,17 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
}
}
// TODO: avoid allocating memory frequently.
// TODO: make part of public API - use different name and put warning that it makes allocations
static void ggml_graph_compute_helper(struct ggml_cgraph * cgraph, int n_threads) {
// same as ggml_graph_compute() but the work data is allocated as a part of the context
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
if (cplan.work_size > 0) {
cplan.work_data = malloc(cplan.work_size);
GGML_ASSERT(cplan.work_data);
}
struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
GGML_ASSERT(buf);
cplan.work_data = buf->data;
ggml_graph_compute(cgraph, &cplan);
if (cplan.work_data) {
free(cplan.work_data);
}
}
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
@ -17292,6 +17288,7 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
//
static enum ggml_opt_result ggml_opt_adam(
struct ggml_context * ctx,
struct ggml_opt_context * opt,
struct ggml_opt_params params,
struct ggml_tensor * f,
@ -17346,7 +17343,7 @@ static enum ggml_opt_result ggml_opt_adam(
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute_helper(gb, params.n_threads);
ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
opt->adam.fx_best = opt->adam.fx_prev;
@ -17427,7 +17424,7 @@ static enum ggml_opt_result ggml_opt_adam(
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute_helper(gb, params.n_threads);
ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
const float fx = ggml_get_f32_1d(f, 0);
@ -17498,6 +17495,7 @@ struct ggml_lbfgs_iteration_data {
};
static enum ggml_opt_result linesearch_backtracking(
struct ggml_context * ctx,
const struct ggml_opt_params * params,
int nx,
float * x,
@ -17549,7 +17547,7 @@ static enum ggml_opt_result linesearch_backtracking(
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute_helper(gb, params->n_threads);
ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);
ggml_opt_get_grad(np, ps, g);
@ -17669,7 +17667,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
ggml_graph_compute_helper(gb, params.n_threads);
ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
ggml_opt_get_grad(np, ps, g);
@ -17728,7 +17726,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
ggml_vec_cpy_f32(nx, xp, x);
ggml_vec_cpy_f32(nx, gp, g);
ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
if (ls < 0) {
// linesearch failed - go back to the previous point and return
@ -18030,7 +18028,7 @@ enum ggml_opt_result ggml_opt_resume_g(
switch (opt->params.type) {
case GGML_OPT_ADAM:
{
result = ggml_opt_adam(opt, opt->params, f, gf, gb);
result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
} break;
case GGML_OPT_LBFGS:
{

6
ggml.h
View file

@ -1306,7 +1306,7 @@ extern "C" {
GGML_API void ggml_set_param(
struct ggml_context * ctx,
struct ggml_tensor * tensor);
struct ggml_tensor * tensor);
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
@ -1319,6 +1319,10 @@ extern "C" {
GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
// same as ggml_graph_compute() but the work data is allocated as a part of the context
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);

View file

@ -195,32 +195,6 @@ void print_elements(const char* label, const struct ggml_tensor * t) {
}
struct work_buffer {
size_t size;
uint8_t * data;
};
static uint8_t * work_buffer_resize(struct work_buffer * buf, size_t size) {
if (size == 0) {
return NULL;
}
GGML_ASSERT(buf);
if (buf->size == 0) {
buf->data = malloc(size);
buf->size = size;
} else if (buf->size < size) {
buf->data = realloc(buf->data, size);
buf->size = size;
} else {
// skip shrinking.
}
GGML_ASSERT(buf->data);
return buf->data;
}
bool check_gradient(
const char * op_name,
struct ggml_context * ctx0,
@ -247,28 +221,12 @@ bool check_gradient(
struct ggml_cgraph gf = ggml_build_forward (f);
struct ggml_cgraph gb = ggml_build_backward(ctx0, &gf, false);
struct work_buffer buf = { /*.size = */ 0, /*.data =*/ NULL };
{
struct ggml_cplan pf = ggml_graph_plan(&gf, n_threads);
if (pf.work_size > 0) {
pf.work_data = malloc(pf.work_size);
GGML_ASSERT(pf.work_data);
}
ggml_graph_compute(&gf, &pf);
if (pf.work_data) {
free(pf.work_data);
}
}
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
ggml_graph_reset (&gf);
ggml_set_f32 (f->grad, 1.0f);
{
struct ggml_cplan pf = ggml_graph_plan(&gb, n_threads);
pf.work_data = work_buffer_resize(&buf, pf.work_size);
ggml_graph_compute(&gf, &pf);
}
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
// ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot");
// ggml_graph_dump_dot(&gb, &gf, "test-grad0-backward.dot");
@ -282,24 +240,15 @@ bool check_gradient(
const float xp = x0 + eps;
set_element(x[i], k, xp);
{
struct ggml_cplan pf = ggml_graph_plan(&gf, n_threads);
pf.work_data = work_buffer_resize(&buf, pf.work_size);
ggml_graph_compute(&gf, &pf);
}
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
const float f0 = ggml_get_f32_1d(f, 0);
set_element(x[i], k, xm);
{
struct ggml_cplan pf = ggml_graph_plan(&gf, n_threads);
pf.work_data = work_buffer_resize(&buf, pf.work_size);
ggml_graph_compute(&gf, &pf);
}
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
const float f1 = ggml_get_f32_1d(f, 0);
const float g0 = (f0 - f1)/(2.0f*eps);
set_element(x[i], k, x0);
@ -308,11 +257,7 @@ bool check_gradient(
ggml_graph_reset (&gf);
ggml_set_f32 (f->grad, 1.0f);
{
struct ggml_cplan pf = ggml_graph_plan(&gb, n_threads);
pf.work_data = work_buffer_resize(&buf, pf.work_size);
ggml_graph_compute(&gf, &pf);
}
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
const float g1 = get_element(x[i]->grad, k);
@ -328,10 +273,6 @@ bool check_gradient(
}
}
if (buf.data) {
free(buf.data);
}
return true;
}

View file

@ -115,31 +115,6 @@ void set_element(struct ggml_tensor * t, int idx, float value) {
((float *)t->data)[idx] = value;
}
struct work_buffer {
size_t size;
uint8_t * data;
};
static uint8_t * work_buffer_resize(struct work_buffer * buf, size_t size) {
if (size == 0) {
return NULL;
}
if (buf->size == 0) {
buf->data = malloc(size);
buf->size = size;
} else if (buf->size < size) {
buf->data = realloc(buf->data, size);
buf->size = size;
} else {
// skip shrinking.
}
GGML_ASSERT(buf->data);
return buf->data;
}
int main(void) {
struct ggml_init_params params = {
.mem_size = 1024*1024*1024,
@ -163,16 +138,10 @@ int main(void) {
struct ggml_tensor * d = ggml_sub(ctx, c, ab);
struct ggml_tensor * e = ggml_sum(ctx, ggml_sqr(ctx, d));
struct ggml_cgraph ge = ggml_build_forward(e);
ggml_graph_reset (&ge);
ggml_graph_reset(&ge);
struct work_buffer buf = { /*.size = */ 0, /*.data =*/ NULL };
{
struct ggml_cplan pe = ggml_graph_plan(&ge, /*n_threads*/ 1);
pe.work_data = work_buffer_resize(&buf, pe.work_size);
ggml_graph_compute(&ge, &pe);
}
ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
const float fe = ggml_get_f32_1d(e, 0);
printf("%s: e = %.4f\n", __func__, fe);
@ -181,17 +150,9 @@ int main(void) {
ggml_opt(ctx, opt_params, e);
ggml_graph_reset (&ge);
ggml_graph_reset(&ge);
{
struct ggml_cplan pe = ggml_graph_plan(&ge, /*n_threads*/ 1);
pe.work_data = work_buffer_resize(&buf, pe.work_size);
ggml_graph_compute(&ge, &pe);
}
if (buf.data) {
free(buf.data);
}
ggml_graph_compute_with_ctx(ctx, &ge, /*n_threads*/ 1);
const float fe_opt = ggml_get_f32_1d(e, 0);
printf("%s: original e = %.4f\n", __func__, fe);