add gradient clipping to AdamW

parent d39c8e6863
commit d395b19c8c

3 changed files with 98 additions and 9 deletions
@@ -2430,7 +2430,8 @@ void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt)
     GGML_ASSERT(opt->nx >= 0);
     GGML_ASSERT(opt->iter >= 0);
     file->write_u32(version);
-    file->write_raw(&opt->params, sizeof(opt->params));
+    file->write_u32(opt->params.past);
+    file->write_u32(opt->params.lbfgs.m);
     file->write_raw(&opt->nx, sizeof(opt->nx));
     file->write_raw(&opt->iter, sizeof(opt->iter));
     file->write_u32((uint32_t) opt->just_initialized);
@@ -2469,8 +2470,43 @@ void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt)
     }
 }

+struct ggml_opt_params_v0 {
+    enum ggml_opt_type type;
+    int n_threads;
+    int past;
+    float delta;
+    int max_no_improvement;
+    bool print_forward_graph;
+    bool print_backward_graph;
+    struct {
+        int n_iter;
+        float sched;
+        float decay;
+        float alpha;
+        float beta1;
+        float beta2;
+        float eps;
+        float eps_f;
+        float eps_g;
+    } adam;
+    struct {
+        int m;
+        int n_iter;
+        int max_linesearch;
+        float eps;
+        float ftol;
+        float wolfe;
+        float min_step;
+        float max_step;
+        enum ggml_linesearch linesearch;
+    } lbfgs;
+};
+
 void read_opt_context_v0(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
-    file->read_raw(&opt->params, sizeof(opt->params));
+    ggml_opt_params_v0 pv0;
+    file->read_raw(&pv0, sizeof(pv0));
+    opt->params.past = pv0.past;
+    opt->params.lbfgs.m = pv0.lbfgs.m;
     file->read_raw(&opt->nx, sizeof(opt->nx));
     ggml_opt_init(ctx, opt, opt->params, opt->nx);

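
Side note on the frozen ggml_opt_params_v0 copy introduced above: appending gclip to ggml_opt_params (see the ggml.h hunk at the bottom of this commit) changes the size of the struct, so version-0 opt contexts, which were written as a raw dump of the old struct, can only be re-read through the old layout. A hypothetical check, not part of the commit, that makes the point from inside the same source file (printf comes from the headers the file already includes):

    // illustrative only: once gclip exists, the raw-dump sizes no longer match,
    // which is exactly why read_opt_context_v0 parses through the frozen v0 layout
    printf("v0 params: %zu bytes, current params: %zu bytes\n",
           sizeof(ggml_opt_params_v0), sizeof(struct ggml_opt_params));
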
@@ -2516,7 +2552,8 @@ void read_opt_context_v0(struct llama_file * file, struct ggml_context * ctx, st
 }

 void read_opt_context_v1(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
-    file->read_raw(&opt->params, sizeof(opt->params));
+    opt->params.past = (int) file->read_u32();
+    opt->params.lbfgs.m = (int) file->read_u32();
     file->read_raw(&opt->nx, sizeof(opt->nx));
     ggml_opt_init(ctx, opt, opt->params, opt->nx);

@@ -2558,6 +2595,7 @@ void read_opt_context_v1(struct llama_file * file, struct ggml_context * ctx, st

 void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
     uint32_t version = file->read_u32();
+    printf("%s: opt context version %u\n", __func__, version);
     switch (version) {
         case 0:
             {
@@ -2569,7 +2607,7 @@ void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struc
             } break;
         default:
             {
-                fprintf(stderr, "%s: unknown version %ud\n", __func__, version);
+                fprintf(stderr, "%s: unknown version %u\n", __func__, version);
             }
     }
 }
@@ -2783,6 +2821,9 @@ struct train_params {
     int adam_n_iter;
     float adam_alpha;
     float adam_decay;
+    float adam_beta1;
+    float adam_beta2;
+    float adam_gclip;

     int mem_model_gb;
     int mem_compute_gb;
@@ -2830,6 +2871,9 @@ struct train_params get_default_train_params() {
     params.adam_n_iter = 16;
     params.adam_alpha = 1e-3f;
     params.adam_decay = 1e-3f;
+    params.adam_beta1 = 0.9f;
+    params.adam_beta2 = 0.999f;
+    params.adam_gclip = 1.0f;

     params.mem_model_gb = 2;
     params.mem_compute_gb = 24;
@@ -2877,6 +2921,9 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
     fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
     fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+    fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
+    fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
+    fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
     fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
     fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
     fprintf(stderr, " --mem-compute0 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
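
For context, an illustrative invocation of the new options would pass them alongside the existing Adam flags, e.g. --adam-beta1 0.9 --adam-beta2 0.999 --adam-gclip 1.0; the values shown here are simply the new defaults from get_default_train_params(), not a recommendation from this commit.
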
@@ -3066,6 +3113,24 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->adam_decay = std::stof(argv[i]);
+        } else if (arg == "--adam-beta1") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_beta1 = std::stof(argv[i]);
+        } else if (arg == "--adam-beta2") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_beta2 = std::stof(argv[i]);
+        } else if (arg == "--adam-gclip") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_gclip = std::stof(argv[i]);
         } else if (arg == "--mem-model") {
             if (++i >= argc) {
                 invalid_param = true;
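
The three new branches repeat the if (++i >= argc) { invalid_param = true; break; } pattern already used by --adam-decay. A hypothetical helper that would collapse that repetition, not part of this commit and assuming the parser keeps its invalid_param flag (std::stof needs <string>, which the parser already uses):

    // hypothetical helper, not in this commit: read the next argv entry as a
    // float; returns false so the caller can set invalid_param and break
    static bool parse_float_arg(int argc, char ** argv, int & i, float & out) {
        if (++i >= argc) {
            return false;
        }
        out = std::stof(argv[i]);
        return true;
    }
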
@@ -3212,6 +3277,9 @@ int main(int argc, char ** argv) {
     opt_params_adam.adam.sched = 1.0f;
     opt_params_adam.adam.alpha = params.adam_alpha;
     opt_params_adam.adam.decay = params.adam_decay;
+    opt_params_adam.adam.beta1 = params.adam_beta1;
+    opt_params_adam.adam.beta2 = params.adam_beta2;
+    opt_params_adam.adam.gclip = params.adam_gclip;

     opt_params_lbfgs.print_forward_graph = false;
     opt_params_lbfgs.print_backward_graph = false;
ggml.c (28 changes)
@@ -17354,6 +17354,7 @@ static enum ggml_opt_result ggml_opt_adam(
     const float beta1 = params.adam.beta1;
     const float beta2 = params.adam.beta2;
     const float eps = params.adam.eps;
+    const float gclip = params.adam.gclip;

     float * m = opt->adam.m->data; // first moment
     float * v = opt->adam.v->data; // second moment
@@ -17404,16 +17405,34 @@ static enum ggml_opt_result ggml_opt_adam(
         UNUSED(t_start_cpu);

         {
+            float gnorm = 1.0f;
+            if (gclip > 0.0f) {
+                // gradient clipping
+                ggml_float sum = 0.0;
+                for (int p = 0; p < np; ++p) {
+                    const int64_t ne = ggml_nelements(ps[p]);
+                    for (int64_t j = 0; j < ne; ++j) {
+                        float g = ggml_get_f32_1d(ps[p]->grad, j);
+                        sum += g*g;
+                    }
+                }
+                ggml_float norm = sqrt(sum);
+                if (norm > (ggml_float) gclip) {
+                    gnorm = (float) ((ggml_float) gclip / norm);
+                }
+            }
+            const float beta1h = alpha/(1.0f - powf(beta1, opt->iter));
+            const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter));
             int64_t i = 0;
             for (int p = 0; p < np; ++p) {
-                const int64_t ne = ggml_nelements(ps[p]) ;
+                const int64_t ne = ggml_nelements(ps[p]);
                 for (int64_t j = 0; j < ne; ++j) {
                     float x = ggml_get_f32_1d(ps[p], j);
-                    float g = ggml_get_f32_1d(ps[p]->grad, j);
+                    float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
                     m[i] = m[i]*beta1 + g*(1.0f - beta1);
                     v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
-                    float mh = m[i]*alpha/(1.0f - powf(beta1, opt->iter));
-                    float vh = v[i]*1.0f /(1.0f - powf(beta2, opt->iter));
+                    float mh = m[i]*beta1h;
+                    float vh = v[i]*beta2h;
                     vh = sqrtf(vh) + eps;
                     x = x*(1.0f - decay) - mh/vh;
                     ggml_set_f32_1d(ps[p], j, x);
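
For reference, the per-scalar update the loop above now performs is, with t = opt->iter, g the raw gradient, and the same clipping scale applied before both moment updates (gclip = 0 disables clipping):

    \hat{g} = g \cdot \min\!\left(1,\ \frac{\mathrm{gclip}}{\lVert g \rVert_2}\right), \qquad
    m \leftarrow \beta_1 m + (1-\beta_1)\,\hat{g}, \qquad
    v \leftarrow \beta_2 v + (1-\beta_2)\,\hat{g}^2

    x \leftarrow x\,(1-\mathrm{decay}) \;-\; \frac{\alpha\, m/(1-\beta_1^{\,t})}{\sqrt{v/(1-\beta_2^{\,t})} + \epsilon}

Here the L2 norm is the global norm over all parameter gradients, and the hoisted beta1h/beta2h factors are just the bias-correction terms alpha/(1 - beta1^t) and 1/(1 - beta2^t) computed once per step instead of once per scalar.
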
@@ -17902,6 +17921,7 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
                 .eps = 1e-8f,
                 .eps_f = 1e-5f,
                 .eps_g = 1e-3f,
+                .gclip = 0.0f,
             },
         };
     } break;
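
A minimal sketch of how a ggml user would opt in to clipping, not taken from this commit: ggml_opt_default_params and ggml_opt are existing ggml entry points, while the enum spelling GGML_OPT_ADAM, the value 1.0f, and the names ctx and f (the context and loss tensor the caller has already built) are assumptions here.

    // minimal sketch, not part of this commit: enable gradient clipping from user code
    struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_ADAM);
    opt_params.adam.gclip = 1.0f;   // 0.0f, the new default, leaves clipping disabled
    enum ggml_opt_result res = ggml_opt(ctx, opt_params, f);
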
ggml.h (1 change)
@@ -1509,6 +1509,7 @@ extern "C" {
             float eps; // epsilon for numerical stability
             float eps_f; // epsilon for convergence test
             float eps_g; // epsilon for convergence test
+            float gclip; // gradient clipping
         } adam;

         // LBFGS parameters