remove unnecessary Adam(W) optimizer tensors.

reduces optimizer memory overhead from 7*modelsize to 2*modelsize.

additionally allows optimizing models with more than 2^31 parameters by replacing int with int64_t.

bumps training checkpoint file version, but old checkpoints can still be read.
A new version with fewer tensors is saved.
This commit is contained in:
xaedes 2023-06-15 21:07:56 +02:00
parent 5d124d0cb4
commit d39c8e6863
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
3 changed files with 144 additions and 78 deletions

View file

@ -2406,8 +2406,27 @@ void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
file->read_raw(tensor->data, ggml_nbytes(tensor));
}
void skip_tensor(struct llama_file * file) {
int32_t nd = file->read_u32();
uint32_t name_len = file->read_u32();
enum ggml_type type = (enum ggml_type) file->read_u32();
uint32_t ne[4] = { 1, 1, 1, 1 };
file->read_raw(ne, sizeof(ne[0]) * nd);
std::string name = file->read_string(name_len);
file->seek(-file->tell() & 31, SEEK_CUR);
size_t nelements = ne[0]*ne[1]*ne[2]*ne[3];
size_t nbytes = nelements*ggml_type_size(type)/ggml_blck_size(type);
file->seek(nbytes, SEEK_CUR);
}
void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt) {
const uint32_t version = 0;
const uint32_t version = 1;
GGML_ASSERT(opt->nx >= 0);
GGML_ASSERT(opt->iter >= 0);
file->write_u32(version);
@ -2418,14 +2437,10 @@ void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt)
switch (opt->params.type) {
case GGML_OPT_ADAM:
{
GGML_ASSERT(opt->adam.x != NULL);
write_tensor(file, opt->adam.x);
write_tensor(file, opt->adam.g1);
write_tensor(file, opt->adam.g2);
GGML_ASSERT(opt->adam.m != NULL);
GGML_ASSERT(opt->adam.v != NULL);
write_tensor(file, opt->adam.m);
write_tensor(file, opt->adam.v);
write_tensor(file, opt->adam.mh);
write_tensor(file, opt->adam.vh);
write_tensor(file, opt->adam.pf);
file->write_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best));
file->write_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev));
@ -2433,7 +2448,7 @@ void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt)
} break;
case GGML_OPT_LBFGS:
{
GGML_ASSERT(opt->adam.x != NULL);
GGML_ASSERT(opt->lbfgs.x != NULL);
write_tensor(file, opt->lbfgs.x);
write_tensor(file, opt->lbfgs.xp);
write_tensor(file, opt->lbfgs.g);
@ -2454,10 +2469,7 @@ void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt)
}
}
void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
uint32_t version = file->read_u32();
GGML_ASSERT(version == 0);
void read_opt_context_v0(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
file->read_raw(&opt->params, sizeof(opt->params));
file->read_raw(&opt->nx, sizeof(opt->nx));
ggml_opt_init(ctx, opt, opt->params, opt->nx);
@ -2468,13 +2480,13 @@ void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struc
switch (opt->params.type) {
case GGML_OPT_ADAM:
{
read_tensor(file, opt->adam.x);
read_tensor(file, opt->adam.g1);
read_tensor(file, opt->adam.g2);
skip_tensor(file);
skip_tensor(file);
skip_tensor(file);
read_tensor(file, opt->adam.m);
read_tensor(file, opt->adam.v);
read_tensor(file, opt->adam.mh);
read_tensor(file, opt->adam.vh);
skip_tensor(file);
skip_tensor(file);
if (opt->adam.pf) { read_tensor(file, opt->adam.pf); }
file->read_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best));
file->read_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev));
@ -2482,7 +2494,7 @@ void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struc
} break;
case GGML_OPT_LBFGS:
{
GGML_ASSERT(opt->adam.x != NULL);
GGML_ASSERT(opt->lbfgs.x != NULL);
read_tensor(file, opt->lbfgs.x);
read_tensor(file, opt->lbfgs.xp);
read_tensor(file, opt->lbfgs.g);
@ -2503,6 +2515,65 @@ void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struc
}
}
// Read a version-1 optimizer context from a checkpoint file.
// Version 1 stores only the Adam moment tensors m and v (plus optional pf);
// the x/g1/g2/mh/vh tensors of version 0 are no longer serialized.
// NOTE: reads are strictly sequential and must exactly mirror the field
// order emitted by write_opt_context — do not reorder.
void read_opt_context_v1(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
file->read_raw(&opt->params, sizeof(opt->params));
file->read_raw(&opt->nx, sizeof(opt->nx));
// allocate the optimizer tensors (sized from nx) before reading into them
ggml_opt_init(ctx, opt, opt->params, opt->nx);
file->read_raw(&opt->iter, sizeof(opt->iter));
opt->just_initialized = (bool) file->read_u32();
switch (opt->params.type) {
case GGML_OPT_ADAM:
{
read_tensor(file, opt->adam.m);  // first moment
read_tensor(file, opt->adam.v);  // second moment
// pf (past function values) is only allocated when params.past > 0;
// writer and reader must agree on params for this to stay in sync
if (opt->adam.pf) { read_tensor(file, opt->adam.pf); }
file->read_raw(&opt->adam.fx_best, sizeof(opt->adam.fx_best));
file->read_raw(&opt->adam.fx_prev, sizeof(opt->adam.fx_prev));
file->read_raw(&opt->adam.n_no_improvement, sizeof(opt->adam.n_no_improvement));
} break;
case GGML_OPT_LBFGS:
{
GGML_ASSERT(opt->lbfgs.x != NULL);
read_tensor(file, opt->lbfgs.x);
read_tensor(file, opt->lbfgs.xp);
read_tensor(file, opt->lbfgs.g);
read_tensor(file, opt->lbfgs.gp);
read_tensor(file, opt->lbfgs.d);
if (opt->lbfgs.pf) { read_tensor(file, opt->lbfgs.pf); }
read_tensor(file, opt->lbfgs.lmal);
read_tensor(file, opt->lbfgs.lmys);
read_tensor(file, opt->lbfgs.lms);
read_tensor(file, opt->lbfgs.lmy);
file->read_raw(&opt->lbfgs.fx_best, sizeof(opt->lbfgs.fx_best));
file->read_raw(&opt->lbfgs.step, sizeof(opt->lbfgs.step));
file->read_raw(&opt->lbfgs.j, sizeof(opt->lbfgs.j));
file->read_raw(&opt->lbfgs.k, sizeof(opt->lbfgs.k));
file->read_raw(&opt->lbfgs.end, sizeof(opt->lbfgs.end));
file->read_raw(&opt->lbfgs.n_no_improvement, sizeof(opt->lbfgs.n_no_improvement));
} break;
}
}
// Read an optimizer context from a checkpoint file, dispatching on the
// on-disk version number so that old (v0) checkpoints remain loadable.
void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
    uint32_t version = file->read_u32();
    switch (version) {
        case 0:
            {
                read_opt_context_v0(file, ctx, opt);
            } break;
        case 1:
            {
                read_opt_context_v1(file, ctx, opt);
            } break;
        default:
            {
                // fixed format specifier: "%ud" printed a stray 'd' after the
                // number; %u is the correct conversion for a uint32_t here.
                // NOTE(review): on an unknown version we only warn and return,
                // leaving *opt uninitialized — callers presumably treat this
                // as fatal; consider aborting here instead. TODO confirm.
                fprintf(stderr, "%s: unknown version %u\n", __func__, version);
            }
    }
}
void save_checkpoint(struct my_llama_model * model, struct ggml_opt_context * opt, const char * filename) {
struct llama_file file(filename, "wb");
if (file.fp == NULL) {

110
ggml.c
View file

@ -17329,7 +17329,7 @@ static enum ggml_opt_result ggml_opt_adam(
struct ggml_tensor * ps[GGML_MAX_PARAMS];
int np = 0;
int nx = 0;
int64_t nx = 0;
for (int i = 0; i < gf->n_nodes; ++i) {
if (gf->nodes[i]->is_param) {
GGML_PRINT_DEBUG("found param %d: grad->op = %d\n", np, gf->nodes[i]->grad->op);
@ -17355,19 +17355,11 @@ static enum ggml_opt_result ggml_opt_adam(
const float beta2 = params.adam.beta2;
const float eps = params.adam.eps;
float * x = opt->adam.x->data; // view of the parameters
float * g1 = opt->adam.g1->data; // gradient
float * g2 = opt->adam.g2->data; // gradient squared
float * m = opt->adam.m->data; // first moment
float * v = opt->adam.v->data; // second moment
float * mh = opt->adam.mh->data; // first moment hat
float * vh = opt->adam.vh->data; // second moment hat
float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
// update view
ggml_opt_get_params(np, ps, x);
// compute the function value
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
@ -17412,43 +17404,61 @@ static enum ggml_opt_result ggml_opt_adam(
UNUSED(t_start_cpu);
{
// update the gradient
ggml_opt_get_grad(np, ps, g1);
// m_t = beta1*m_t-1 + (1 - beta1)*g_t
ggml_vec_scale_f32(nx, m, beta1);
ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1);
// g2 = g1^2
ggml_vec_sqr_f32 (nx, g2, g1);
// v_t = beta2*v_t-1 + (1 - beta2)*g_t^2
ggml_vec_scale_f32(nx, v, beta2);
ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2);
// m^hat = m_t / (1 - beta1^t)
// v^hat = v_t / (1 - beta2^t)
// x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
// x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
// x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
// x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
// x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
ggml_vec_cpy_f32 (nx, mh, m);
ggml_vec_cpy_f32 (nx, vh, v);
ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));
ggml_vec_sqrt_f32 (nx, vh, vh);
ggml_vec_acc1_f32 (nx, vh, eps);
ggml_vec_div_f32 (nx, mh, mh, vh);
ggml_vec_scale_f32(nx, x, 1.0f - decay);
ggml_vec_sub_f32 (nx, x, x, mh);
// update the parameters
ggml_opt_set_params(np, ps, x);
int64_t i = 0;
for (int p = 0; p < np; ++p) {
const int64_t ne = ggml_nelements(ps[p]) ;
for (int64_t j = 0; j < ne; ++j) {
float x = ggml_get_f32_1d(ps[p], j);
float g = ggml_get_f32_1d(ps[p]->grad, j);
m[i] = m[i]*beta1 + g*(1.0f - beta1);
v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
float mh = m[i]*alpha/(1.0f - powf(beta1, opt->iter));
float vh = v[i]*1.0f /(1.0f - powf(beta2, opt->iter));
vh = sqrtf(vh) + eps;
x = x*(1.0f - decay) - mh/vh;
ggml_set_f32_1d(ps[p], j, x);
++i;
}
}
}
// {
// // update the gradient
// ggml_opt_get_grad(np, ps, g1);
// // m_t = beta1*m_t-1 + (1 - beta1)*g_t
// ggml_vec_scale_f32(nx, m, beta1);
// ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1);
// // g2 = g1^2
// ggml_vec_sqr_f32 (nx, g2, g1);
// // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2
// ggml_vec_scale_f32(nx, v, beta2);
// ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2);
// // m^hat = m_t / (1 - beta1^t)
// // v^hat = v_t / (1 - beta2^t)
// // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
// // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
// // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
// // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
// // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
// ggml_vec_cpy_f32 (nx, mh, m);
// ggml_vec_cpy_f32 (nx, vh, v);
// ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
// ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));
// ggml_vec_sqrt_f32 (nx, vh, vh);
// ggml_vec_acc1_f32 (nx, vh, eps);
// ggml_vec_div_f32 (nx, mh, mh, vh);
// ggml_vec_scale_f32(nx, x, 1.0f - decay);
// ggml_vec_sub_f32 (nx, x, x, mh);
// // update the parameters
// ggml_opt_set_params(np, ps, x);
// }
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
@ -17941,23 +17951,13 @@ GGML_API void ggml_opt_init(
switch (opt->params.type) {
case GGML_OPT_ADAM:
{
opt->adam.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->adam.g1 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->adam.g2 = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->adam.mh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->adam.vh = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->adam.pf = params.past > 0
? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
: NULL;
ggml_set_zero(opt->adam.x);
ggml_set_zero(opt->adam.g1);
ggml_set_zero(opt->adam.g2);
ggml_set_zero(opt->adam.m);
ggml_set_zero(opt->adam.v);
ggml_set_zero(opt->adam.mh);
ggml_set_zero(opt->adam.vh);
if (opt->adam.pf) {
ggml_set_zero(opt->adam.pf);
}

5
ggml.h
View file

@ -1537,13 +1537,8 @@ extern "C" {
bool just_initialized;
struct {
struct ggml_tensor * x; // view of the parameters
struct ggml_tensor * g1; // gradient
struct ggml_tensor * g2; // gradient squared
struct ggml_tensor * m; // first moment
struct ggml_tensor * v; // second moment
struct ggml_tensor * mh; // first moment hat
struct ggml_tensor * vh; // second moment hat
struct ggml_tensor * pf; // past function values
float fx_best;
float fx_prev;