From d39c8e686375b4e2dedbf98e2e11b12b1aef2526 Mon Sep 17 00:00:00 2001 From: xaedes Date: Thu, 15 Jun 2023 21:07:56 +0200 Subject: [PATCH] remove unnecessary Adam(W) optimizer tensors. reduces optimizer memory overhead from 7*modelsize to 2*modelsize. additionally allows to optimize models with more than 2^31 parameters by replacing int with int64_t. bumps training checkpoint file version, but old checkpoints can still be read. new version with less tensors is saved. --- .../train-text-from-scratch.cpp | 107 ++++++++++++++--- ggml.c | 110 +++++++++--------- ggml.h | 5 - 3 files changed, 144 insertions(+), 78 deletions(-) diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 828a2a9b7..60d2b5783 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -2406,8 +2406,27 @@ void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) { file->read_raw(tensor->data, ggml_nbytes(tensor)); } +void skip_tensor(struct llama_file * file) { + int32_t nd = file->read_u32(); + + uint32_t name_len = file->read_u32(); + enum ggml_type type = (enum ggml_type) file->read_u32(); + + uint32_t ne[4] = { 1, 1, 1, 1 }; + + file->read_raw(ne, sizeof(ne[0]) * nd); + + std::string name = file->read_string(name_len); + + file->seek(-file->tell() & 31, SEEK_CUR); + + size_t nelements = ne[0]*ne[1]*ne[2]*ne[3]; + size_t nbytes = nelements*ggml_type_size(type)/ggml_blck_size(type); + file->seek(nbytes, SEEK_CUR); +} + void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt) { - const uint32_t version = 0; + const uint32_t version = 1; GGML_ASSERT(opt->nx >= 0); GGML_ASSERT(opt->iter >= 0); file->write_u32(version); @@ -2418,14 +2437,10 @@ void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt) switch (opt->params.type) { case GGML_OPT_ADAM: { - GGML_ASSERT(opt->adam.x != NULL); - write_tensor(file, opt->adam.x); - write_tensor(file, opt->adam.g1); - write_tensor(file, opt->adam.g2); + GGML_ASSERT(opt->adam.m != NULL); + GGML_ASSERT(opt->adam.v != NULL); write_tensor(file, opt->adam.m); write_tensor(file, opt->adam.v); - write_tensor(file, opt->adam.mh); - write_tensor(file, opt->adam.vh); write_tensor(file, opt->adam.pf); file->write_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best)); file->write_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev)); @@ -2433,7 +2448,7 @@ void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt) } break; case GGML_OPT_LBFGS: { - GGML_ASSERT(opt->adam.x != NULL); + GGML_ASSERT(opt->lbfgs.x != NULL); write_tensor(file, opt->lbfgs.x); write_tensor(file, opt->lbfgs.xp); write_tensor(file, opt->lbfgs.g); @@ -2454,10 +2469,7 @@ void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt) } } -void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) { - uint32_t version = file->read_u32(); - GGML_ASSERT(version == 0); - +void read_opt_context_v0(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) { file->read_raw(&opt->params, sizeof(opt->params)); file->read_raw(&opt->nx, sizeof(opt->nx)); ggml_opt_init(ctx, opt, opt->params, opt->nx); @@ -2468,13 +2480,13 @@ void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struc switch (opt->params.type) { case GGML_OPT_ADAM: { - read_tensor(file, opt->adam.x); - read_tensor(file, 
opt->adam.g1); - read_tensor(file, opt->adam.g2); + skip_tensor(file); + skip_tensor(file); + skip_tensor(file); read_tensor(file, opt->adam.m); read_tensor(file, opt->adam.v); - read_tensor(file, opt->adam.mh); - read_tensor(file, opt->adam.vh); + skip_tensor(file); + skip_tensor(file); if (opt->adam.pf) { read_tensor(file, opt->adam.pf); } file->read_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best)); file->read_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev)); @@ -2482,7 +2494,7 @@ void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struc } break; case GGML_OPT_LBFGS: { - GGML_ASSERT(opt->adam.x != NULL); + GGML_ASSERT(opt->lbfgs.x != NULL); read_tensor(file, opt->lbfgs.x); read_tensor(file, opt->lbfgs.xp); read_tensor(file, opt->lbfgs.g); @@ -2503,6 +2515,65 @@ void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struc } } +void read_opt_context_v1(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) { + file->read_raw(&opt->params, sizeof(opt->params)); + file->read_raw(&opt->nx, sizeof(opt->nx)); + ggml_opt_init(ctx, opt, opt->params, opt->nx); + + file->read_raw(&opt->iter, sizeof(opt->iter)); + opt->just_initialized = (bool) file->read_u32(); + + switch (opt->params.type) { + case GGML_OPT_ADAM: + { + read_tensor(file, opt->adam.m); + read_tensor(file, opt->adam.v); + if (opt->adam.pf) { read_tensor(file, opt->adam.pf); } + file->read_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best)); + file->read_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev)); + file->read_raw(&opt->adam.n_no_improvement, sizeof(opt->adam.n_no_improvement)); + } break; + case GGML_OPT_LBFGS: + { + GGML_ASSERT(opt->lbfgs.x != NULL); + read_tensor(file, opt->lbfgs.x); + read_tensor(file, opt->lbfgs.xp); + read_tensor(file, opt->lbfgs.g); + read_tensor(file, opt->lbfgs.gp); + read_tensor(file, opt->lbfgs.d); + if (opt->lbfgs.pf) { read_tensor(file, opt->lbfgs.pf); } + read_tensor(file, opt->lbfgs.lmal); + read_tensor(file, opt->lbfgs.lmys); + read_tensor(file, opt->lbfgs.lms); + read_tensor(file, opt->lbfgs.lmy); + file->read_raw(&opt->lbfgs.fx_best, sizeof(opt->lbfgs.fx_best)); + file->read_raw(&opt->lbfgs.step, sizeof(opt->lbfgs.step)); + file->read_raw(&opt->lbfgs.j, sizeof(opt->lbfgs.j)); + file->read_raw(&opt->lbfgs.k, sizeof(opt->lbfgs.k)); + file->read_raw(&opt->lbfgs.end, sizeof(opt->lbfgs.end)); + file->read_raw(&opt->lbfgs.n_no_improvement, sizeof(opt->lbfgs.n_no_improvement)); + } break; + } +} + +void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) { + uint32_t version = file->read_u32(); + switch (version) { + case 0: + { + read_opt_context_v0(file, ctx, opt); + } break; + case 1: + { + read_opt_context_v1(file, ctx, opt); + } break; + default: + { + fprintf(stderr, "%s: unknown version %ud\n", __func__, version); + } + } +} + void save_checkpoint(struct my_llama_model * model, struct ggml_opt_context * opt, const char * filename) { struct llama_file file(filename, "wb"); if (file.fp == NULL) { diff --git a/ggml.c b/ggml.c index b77f99267..143f88d4a 100644 --- a/ggml.c +++ b/ggml.c @@ -17329,7 +17329,7 @@ static enum ggml_opt_result ggml_opt_adam( struct ggml_tensor * ps[GGML_MAX_PARAMS]; int np = 0; - int nx = 0; + int64_t nx = 0; for (int i = 0; i < gf->n_nodes; ++i) { if (gf->nodes[i]->is_param) { GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op); @@ -17355,19 +17355,11 @@ static enum ggml_opt_result ggml_opt_adam( const float beta2 = 
params.adam.beta2; const float eps = params.adam.eps; - float * x = opt->adam.x->data; // view of the parameters - float * g1 = opt->adam.g1->data; // gradient - float * g2 = opt->adam.g2->data; // gradient squared float * m = opt->adam.m->data; // first moment float * v = opt->adam.v->data; // second moment - float * mh = opt->adam.mh->data; // first moment hat - float * vh = opt->adam.vh->data; // second moment hat float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values - // update view - ggml_opt_get_params(np, ps, x); - // compute the function value ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); @@ -17412,43 +17404,61 @@ static enum ggml_opt_result ggml_opt_adam( UNUSED(t_start_cpu); { - // update the gradient - ggml_opt_get_grad(np, ps, g1); - - // m_t = beta1*m_t-1 + (1 - beta1)*g_t - ggml_vec_scale_f32(nx, m, beta1); - ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1); - - // g2 = g1^2 - ggml_vec_sqr_f32 (nx, g2, g1); - - // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2 - ggml_vec_scale_f32(nx, v, beta2); - ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2); - - // m^hat = m_t / (1 - beta1^t) - // v^hat = v_t / (1 - beta2^t) - // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1) - // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1 - // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps) - // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps) - // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay) - ggml_vec_cpy_f32 (nx, mh, m); - ggml_vec_cpy_f32 (nx, vh, v); - - ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter))); - ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter))); - - ggml_vec_sqrt_f32 (nx, vh, vh); - ggml_vec_acc1_f32 (nx, vh, eps); - - ggml_vec_div_f32 (nx, mh, mh, vh); - ggml_vec_scale_f32(nx, x, 1.0f - decay); - ggml_vec_sub_f32 (nx, x, x, mh); - - // update the parameters - ggml_opt_set_params(np, ps, x); + int64_t i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = ggml_nelements(ps[p]) ; + for (int64_t j = 0; j < ne; ++j) { + float x = ggml_get_f32_1d(ps[p], j); + float g = ggml_get_f32_1d(ps[p]->grad, j); + m[i] = m[i]*beta1 + g*(1.0f - beta1); + v[i] = v[i]*beta2 + g*g*(1.0f - beta2); + float mh = m[i]*alpha/(1.0f - powf(beta1, opt->iter)); + float vh = v[i]*1.0f /(1.0f - powf(beta2, opt->iter)); + vh = sqrtf(vh) + eps; + x = x*(1.0f - decay) - mh/vh; + ggml_set_f32_1d(ps[p], j, x); + ++i; + } + } } + // { + // // update the gradient + // ggml_opt_get_grad(np, ps, g1); + + // // m_t = beta1*m_t-1 + (1 - beta1)*g_t + // ggml_vec_scale_f32(nx, m, beta1); + // ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1); + + // // g2 = g1^2 + // ggml_vec_sqr_f32 (nx, g2, g1); + + // // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2 + // ggml_vec_scale_f32(nx, v, beta2); + // ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2); + + // // m^hat = m_t / (1 - beta1^t) + // // v^hat = v_t / (1 - beta2^t) + // // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1) + // // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1 + // // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps) + // // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps) + // // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay) + // ggml_vec_cpy_f32 (nx, mh, m); + // ggml_vec_cpy_f32 (nx, vh, v); + + // ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter))); + // 
ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter))); + + // ggml_vec_sqrt_f32 (nx, vh, vh); + // ggml_vec_acc1_f32 (nx, vh, eps); + + // ggml_vec_div_f32 (nx, mh, mh, vh); + // ggml_vec_scale_f32(nx, x, 1.0f - decay); + // ggml_vec_sub_f32 (nx, x, x, mh); + + // // update the parameters + // ggml_opt_set_params(np, ps, x); + // } ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); @@ -17941,23 +17951,13 @@ GGML_API void ggml_opt_init( switch (opt->params.type) { case GGML_OPT_ADAM: { - opt->adam.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); - opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->adam.pf = params.past > 0 ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past) : NULL; - ggml_set_zero(opt->adam.x); - ggml_set_zero(opt->adam.g1); - ggml_set_zero(opt->adam.g2); ggml_set_zero(opt->adam.m); ggml_set_zero(opt->adam.v); - ggml_set_zero(opt->adam.mh); - ggml_set_zero(opt->adam.vh); if (opt->adam.pf) { ggml_set_zero(opt->adam.pf); } diff --git a/ggml.h b/ggml.h index 9919cce7c..531b6cb07 100644 --- a/ggml.h +++ b/ggml.h @@ -1537,13 +1537,8 @@ extern "C" { bool just_initialized; struct { - struct ggml_tensor * x; // view of the parameters - struct ggml_tensor * g1; // gradient - struct ggml_tensor * g2; // gradient squared struct ggml_tensor * m; // first moment struct ggml_tensor * v; // second moment - struct ggml_tensor * mh; // first moment hat - struct ggml_tensor * vh; // second moment hat struct ggml_tensor * pf; // past function values float fx_best; float fx_prev;
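
For reference, a minimal standalone sketch of the fused Adam(W) update that replaces the old seven-tensor vectorized code path. It assumes parameters and gradients are available as flat float arrays; the function and parameter names are illustrative, not part of ggml's API (the patch itself walks each parameter tensor via ggml_get_f32_1d/ggml_set_f32_1d). The point is that only the moments m and v have to persist between iterations; the squared gradient and the bias-corrected m^hat/v^hat are recomputed per element, so the persistent optimizer state drops from 7*nx to 2*nx floats.

    #include <math.h>
    #include <stdint.h>

    // One fused Adam(W) step over nx parameters (sketch, names illustrative).
    // x: parameters, g: gradients, m/v: persistent first/second moments.
    static void adamw_step_sketch(int64_t nx, float * x, const float * g,
                                  float * m, float * v, int iter,
                                  float alpha, float beta1, float beta2,
                                  float eps, float decay) {
        const float bc1 = 1.0f - powf(beta1, (float) iter); // bias correction for m
        const float bc2 = 1.0f - powf(beta2, (float) iter); // bias correction for v
        for (int64_t i = 0; i < nx; ++i) {
            m[i] = m[i]*beta1 + g[i]*(1.0f - beta1);       // m_t = b1*m_{t-1} + (1-b1)*g_t
            v[i] = v[i]*beta2 + g[i]*g[i]*(1.0f - beta2);  // v_t = b2*v_{t-1} + (1-b2)*g_t^2
            const float mh = m[i]*alpha/bc1;               // alpha * m^hat
            const float vh = sqrtf(v[i]/bc2) + eps;        // sqrt(v^hat) + eps
            x[i] = x[i]*(1.0f - decay) - mh/vh;            // decoupled weight decay, then update
        }
    }

Because the patch indexes parameters tensor by tensor inside this loop, the flat x view and the g1/g2/mh/vh scratch tensors are no longer needed at all.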
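
Old checkpoints remain readable because the version-0 reader now calls skip_tensor for the dropped x/g1/g2/mh/vh entries instead of loading them: it parses the same header as read_tensor (n_dims, name length, type, shape, name), realigns to the next 32-byte boundary, and seeks past the payload. A sketch of that alignment arithmetic, with an illustrative helper name (the patch does it inline via file->seek(-file->tell() & 31, SEEK_CUR)):

    #include <stdint.h>

    // Round a file offset up to the next 32-byte boundary, e.g.
    // align32(100) == 128 and align32(128) == 128.
    static int64_t align32(int64_t tell) {
        return tell + ((-tell) & 31);
    }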
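
On the headline 7*modelsize -> 2*modelsize claim: ggml_opt_init used to allocate seven nx-element F32 tensors for Adam (x, g1, g2, m, v, mh, vh) and now allocates only m and v (plus the small pf tensor when params.past > 0). As a rough worked example, assuming F32 optimizer state and nx = 3e9 parameters (a size that also motivates the int -> int64_t change, since 3e9 > 2^31):

    before: 7 * 3e9 params * 4 bytes = 84 GB of Adam state
    after:  2 * 3e9 params * 4 bytes = 24 GB of Adam state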