add gradient clipping to AdamW

Author: xaedes
Date:   2023-06-15 23:48:46 +02:00
Parent: d39c8e6863
Commit: d395b19c8c
Signature: no known key found for this signature in database (GPG key ID: 30030EDD817EA2B1)

3 changed files with 98 additions and 9 deletions

@@ -2430,7 +2430,8 @@ void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt)
     GGML_ASSERT(opt->nx >= 0);
     GGML_ASSERT(opt->iter >= 0);
     file->write_u32(version);
-    file->write_raw(&opt->params, sizeof(opt->params));
+    file->write_u32(opt->params.past);
+    file->write_u32(opt->params.lbfgs.m);
     file->write_raw(&opt->nx, sizeof(opt->nx));
     file->write_raw(&opt->iter, sizeof(opt->iter));
     file->write_u32((uint32_t) opt->just_initialized);
@@ -2469,9 +2470,44 @@ void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt)
     }
 }
 
+struct ggml_opt_params_v0 {
+    enum ggml_opt_type type;
+    int n_threads;
+    int past;
+    float delta;
+    int max_no_improvement;
+    bool print_forward_graph;
+    bool print_backward_graph;
+    struct {
+        int n_iter;
+        float sched;
+        float decay;
+        float alpha;
+        float beta1;
+        float beta2;
+        float eps;
+        float eps_f;
+        float eps_g;
+    } adam;
+    struct {
+        int m;
+        int n_iter;
+        int max_linesearch;
+        float eps;
+        float ftol;
+        float wolfe;
+        float min_step;
+        float max_step;
+        enum ggml_linesearch linesearch;
+    } lbfgs;
+};
+
 void read_opt_context_v0(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
-    file->read_raw(&opt->params, sizeof(opt->params));
-    file->read_raw(&opt->nx, sizeof(opt->nx));
+    ggml_opt_params_v0 pv0;
+    file->read_raw(&pv0, sizeof(pv0));
+    opt->params.past = pv0.past;
+    opt->params.lbfgs.m = pv0.lbfgs.m;
+    file->read_raw(&opt->nx, sizeof(opt->nx));
     ggml_opt_init(ctx, opt, opt->params, opt->nx);
 
     file->read_raw(&opt->iter, sizeof(opt->iter));
@@ -2516,7 +2552,8 @@ void read_opt_context_v0(struct llama_file * file, struct ggml_context * ctx, st
 }
 
 void read_opt_context_v1(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
-    file->read_raw(&opt->params, sizeof(opt->params));
+    opt->params.past = (int) file->read_u32();
+    opt->params.lbfgs.m = (int) file->read_u32();
     file->read_raw(&opt->nx, sizeof(opt->nx));
     ggml_opt_init(ctx, opt, opt->params, opt->nx);
 
@@ -2558,6 +2595,7 @@ void read_opt_context_v1(struct llama_file * file, struct ggml_context * ctx, st
 
 void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
     uint32_t version = file->read_u32();
+    printf("%s: opt context version %u\n", __func__, version);
     switch (version) {
         case 0:
             {
@@ -2569,7 +2607,7 @@ void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struc
             } break;
         default:
             {
-                fprintf(stderr, "%s: unknown version %ud\n", __func__, version);
+                fprintf(stderr, "%s: unknown version %u\n", __func__, version);
             }
     }
 }
@@ -2783,6 +2821,9 @@ struct train_params {
     int adam_n_iter;
     float adam_alpha;
     float adam_decay;
+    float adam_beta1;
+    float adam_beta2;
+    float adam_gclip;
 
     int mem_model_gb;
     int mem_compute_gb;
@@ -2830,6 +2871,9 @@ struct train_params get_default_train_params() {
     params.adam_n_iter = 16;
    params.adam_alpha = 1e-3f;
     params.adam_decay = 1e-3f;
+    params.adam_beta1 = 0.9f;
+    params.adam_beta2 = 0.999f;
+    params.adam_gclip = 1.0f;
 
     params.mem_model_gb = 2;
     params.mem_compute_gb = 24;
@@ -2877,6 +2921,9 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
     fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
     fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+    fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
+    fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
+    fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
     fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
     fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
     fprintf(stderr, " --mem-compute0 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
@@ -3066,6 +3113,24 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->adam_decay = std::stof(argv[i]);
+        } else if (arg == "--adam-beta1") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_beta1 = std::stof(argv[i]);
+        } else if (arg == "--adam-beta2") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_beta2 = std::stof(argv[i]);
+        } else if (arg == "--adam-gclip") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_gclip = std::stof(argv[i]);
         } else if (arg == "--mem-model") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -3212,6 +3277,9 @@ int main(int argc, char ** argv) {
     opt_params_adam.adam.sched = 1.0f;
     opt_params_adam.adam.alpha = params.adam_alpha;
     opt_params_adam.adam.decay = params.adam_decay;
+    opt_params_adam.adam.beta1 = params.adam_beta1;
+    opt_params_adam.adam.beta2 = params.adam_beta2;
+    opt_params_adam.adam.gclip = params.adam_gclip;
 
     opt_params_lbfgs.print_forward_graph = false;
     opt_params_lbfgs.print_backward_graph = false;
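
Note: with the parsing added above, the smoothing coefficients and the clipping threshold can be set from the command line. A hypothetical invocation (the binary name and any omitted arguments are placeholders; only the --adam-* options come from this diff):

    ./train-example --adam-iter 16 --adam-alpha 1e-3 --adam-decay 1e-3 --adam-beta1 0.9 --adam-beta2 0.999 --adam-gclip 1.0

Passing --adam-gclip 0 disables clipping, matching the "Disabled when zero" wording in the usage text.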

ggml.c (28 changed lines)

@@ -17354,6 +17354,7 @@ static enum ggml_opt_result ggml_opt_adam(
     const float beta1 = params.adam.beta1;
     const float beta2 = params.adam.beta2;
     const float eps = params.adam.eps;
+    const float gclip = params.adam.gclip;
 
     float * m = opt->adam.m->data; // first moment
     float * v = opt->adam.v->data; // second moment
@@ -17404,16 +17405,34 @@
         UNUSED(t_start_cpu);
 
         {
+            float gnorm = 1.0f;
+            if (gclip > 0.0f) {
+                // gradient clipping
+                ggml_float sum = 0.0;
+                for (int p = 0; p < np; ++p) {
+                    const int64_t ne = ggml_nelements(ps[p]);
+                    for (int64_t j = 0; j < ne; ++j) {
+                        float g = ggml_get_f32_1d(ps[p]->grad, j);
+                        sum += g*g;
+                    }
+                }
+                ggml_float norm = sqrt(sum);
+                if (norm > (ggml_float) gclip) {
+                    gnorm = (float) ((ggml_float) gclip / norm);
+                }
+            }
+            const float beta1h = alpha/(1.0f - powf(beta1, opt->iter));
+            const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter));
             int64_t i = 0;
             for (int p = 0; p < np; ++p) {
-                const int64_t ne = ggml_nelements(ps[p]) ;
+                const int64_t ne = ggml_nelements(ps[p]);
                 for (int64_t j = 0; j < ne; ++j) {
                     float x = ggml_get_f32_1d(ps[p], j);
-                    float g = ggml_get_f32_1d(ps[p]->grad, j);
+                    float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
                     m[i] = m[i]*beta1 + g*(1.0f - beta1);
                     v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
-                    float mh = m[i]*alpha/(1.0f - powf(beta1, opt->iter));
-                    float vh = v[i]*1.0f /(1.0f - powf(beta2, opt->iter));
+                    float mh = m[i]*beta1h;
+                    float vh = v[i]*beta2h;
                     vh = sqrtf(vh) + eps;
                     x = x*(1.0f - decay) - mh/vh;
                     ggml_set_f32_1d(ps[p], j, x);
@@ -17902,6 +17921,7 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
                .eps = 1e-8f,
                .eps_f = 1e-5f,
                .eps_g = 1e-3f,
+               .gclip = 0.0f,
            },
        };
    } break;
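
Note: the clipping added to ggml_opt_adam is standard global-norm gradient clipping. The squared gradients of all trainable parameters are summed into one L2 norm; if that norm exceeds gclip, every gradient element is scaled by gclip/norm, so the clipped gradient keeps its direction and has norm exactly gclip. The same hunk also hoists the Adam bias-correction factors out of the inner loop: beta1h = alpha/(1 - beta1^t) and beta2h = 1/(1 - beta2^t) are computed once per step, so m[i]*beta1h and v[i]*beta2h give the same update as the previous per-element expressions without calling powf for every parameter. A minimal standalone sketch of the clipping step (plain float arrays instead of ggml tensors; the function name is illustrative, not part of ggml):

    #include <math.h>

    // Scale gradient g (length n) so that its global L2 norm does not exceed max_norm.
    // max_norm <= 0 disables clipping, mirroring gclip == 0.0f in ggml_opt_adam.
    static void clip_grad_global_norm(float * g, long n, float max_norm) {
        if (max_norm <= 0.0f) {
            return;
        }
        double sum = 0.0;
        for (long i = 0; i < n; ++i) {
            sum += (double) g[i] * (double) g[i];
        }
        const double norm = sqrt(sum);
        if (norm > (double) max_norm) {
            const float scale = (float) ((double) max_norm / norm);
            for (long i = 0; i < n; ++i) {
                g[i] *= scale;
            }
        }
    }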

ggml.h (1 changed line)

@@ -1509,6 +1509,7 @@ extern "C" {
            float eps; // epsilon for numerical stability
            float eps_f; // epsilon for convergence test
            float eps_g; // epsilon for convergence test
+           float gclip; // gradient clipping
        } adam;
 
        // LBFGS parameters
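
Note: as the ggml.c and ggml.h hunks show, the new gclip field defaults to 0.0f (clipping disabled) in ggml_opt_default_params, while the training example enables it via --adam-gclip (default 1.0f). A caller using the ggml optimizer API directly would opt in along these lines (a sketch; ggml_opt_default_params, GGML_OPT_ADAM and adam.gclip are the existing or newly added API, the surrounding setup is context-dependent):

    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
    opt_params.adam.gclip = 1.0f; // clip gradients to a global L2 norm of at most 1.0; 0.0f keeps clipping disabled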