diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 60d2b5783..a4a6b05b1 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -2430,7 +2430,8 @@ void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt)
     GGML_ASSERT(opt->nx >= 0);
     GGML_ASSERT(opt->iter >= 0);
     file->write_u32(version);
-    file->write_raw(&opt->params, sizeof(opt->params));
+    file->write_u32(opt->params.past);
+    file->write_u32(opt->params.lbfgs.m);
     file->write_raw(&opt->nx, sizeof(opt->nx));
     file->write_raw(&opt->iter, sizeof(opt->iter));
     file->write_u32((uint32_t) opt->just_initialized);
@@ -2469,9 +2470,44 @@ void write_opt_context(struct llama_file * file, struct ggml_opt_context * opt)
     }
 }
 
+struct ggml_opt_params_v0 {
+    enum ggml_opt_type type;
+    int n_threads;
+    int past;
+    float delta;
+    int max_no_improvement;
+    bool print_forward_graph;
+    bool print_backward_graph;
+    struct {
+        int n_iter;
+        float sched;
+        float decay;
+        float alpha;
+        float beta1;
+        float beta2;
+        float eps;
+        float eps_f;
+        float eps_g;
+    } adam;
+    struct {
+        int m;
+        int n_iter;
+        int max_linesearch;
+        float eps;
+        float ftol;
+        float wolfe;
+        float min_step;
+        float max_step;
+        enum ggml_linesearch linesearch;
+    } lbfgs;
+};
+
 void read_opt_context_v0(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
-    file->read_raw(&opt->params, sizeof(opt->params));
-    file->read_raw(&opt->nx, sizeof(opt->nx));
+    ggml_opt_params_v0 pv0;
+    file->read_raw(&pv0, sizeof(pv0));
+    opt->params.past    = pv0.past;
+    opt->params.lbfgs.m = pv0.lbfgs.m;
+    file->read_raw(&opt->nx, sizeof(opt->nx));
     ggml_opt_init(ctx, opt, opt->params, opt->nx);
 
     file->read_raw(&opt->iter, sizeof(opt->iter));
@@ -2516,7 +2552,8 @@ void read_opt_context_v0(struct llama_file * file, struct ggml_context * ctx, st
 }
 
 void read_opt_context_v1(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
-    file->read_raw(&opt->params, sizeof(opt->params));
+    opt->params.past    = (int) file->read_u32();
+    opt->params.lbfgs.m = (int) file->read_u32();
     file->read_raw(&opt->nx, sizeof(opt->nx));
     ggml_opt_init(ctx, opt, opt->params, opt->nx);
 
@@ -2558,6 +2595,7 @@ void read_opt_context_v1(struct llama_file * file, struct ggml_context * ctx, st
 
 void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struct ggml_opt_context * opt) {
     uint32_t version = file->read_u32();
+    printf("%s: opt context version %u\n", __func__, version);
     switch (version) {
         case 0:
             {
@@ -2569,7 +2607,7 @@ void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struc
             } break;
         default:
             {
-                fprintf(stderr, "%s: unknown version %ud\n", __func__, version);
+                fprintf(stderr, "%s: unknown version %u\n", __func__, version);
             }
     }
 }
@@ -2783,6 +2821,9 @@ struct train_params {
     int adam_n_iter;
     float adam_alpha;
     float adam_decay;
+    float adam_beta1;
+    float adam_beta2;
+    float adam_gclip;
 
     int mem_model_gb;
     int mem_compute_gb;
@@ -2830,6 +2871,9 @@ struct train_params get_default_train_params() {
     params.adam_n_iter = 16;
     params.adam_alpha  = 1e-3f;
     params.adam_decay  = 1e-3f;
+    params.adam_beta1  = 0.9f;
+    params.adam_beta2  = 0.999f;
+    params.adam_gclip  = 1.0f;
 
     params.mem_model_gb   = 2;
     params.mem_compute_gb = 24;
@@ -2877,6 +2921,9 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
     fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
     fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+    fprintf(stderr, "  --adam-beta1 N             AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
+    fprintf(stderr, "  --adam-beta2 N             AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
+    fprintf(stderr, "  --adam-gclip N             AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
     fprintf(stderr, "  --mem-model N              Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
     fprintf(stderr, "  --mem-compute N            Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
     fprintf(stderr, "  --mem-compute0 N           Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
@@ -3066,6 +3113,24 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->adam_decay = std::stof(argv[i]);
+        } else if (arg == "--adam-beta1") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_beta1 = std::stof(argv[i]);
+        } else if (arg == "--adam-beta2") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_beta2 = std::stof(argv[i]);
+        } else if (arg == "--adam-gclip") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_gclip = std::stof(argv[i]);
         } else if (arg == "--mem-model") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -3212,6 +3277,9 @@ int main(int argc, char ** argv) {
     opt_params_adam.adam.sched = 1.0f;
     opt_params_adam.adam.alpha = params.adam_alpha;
     opt_params_adam.adam.decay = params.adam_decay;
+    opt_params_adam.adam.beta1 = params.adam_beta1;
+    opt_params_adam.adam.beta2 = params.adam_beta2;
+    opt_params_adam.adam.gclip = params.adam_gclip;
 
     opt_params_lbfgs.print_forward_graph = false;
     opt_params_lbfgs.print_backward_graph = false;
diff --git a/ggml.c b/ggml.c
index 143f88d4a..19a194beb 100644
--- a/ggml.c
+++ b/ggml.c
@@ -17354,6 +17354,7 @@ static enum ggml_opt_result ggml_opt_adam(
     const float beta1 = params.adam.beta1;
     const float beta2 = params.adam.beta2;
     const float eps   = params.adam.eps;
+    const float gclip = params.adam.gclip;
 
     float * m = opt->adam.m->data; // first moment
     float * v = opt->adam.v->data; // second moment
@@ -17404,16 +17405,34 @@ static enum ggml_opt_result ggml_opt_adam(
         UNUSED(t_start_cpu);
 
         {
+            float gnorm = 1.0f;
+            if (gclip > 0.0f) {
+                // gradient clipping
+                ggml_float sum = 0.0;
+                for (int p = 0; p < np; ++p) {
+                    const int64_t ne = ggml_nelements(ps[p]);
+                    for (int64_t j = 0; j < ne; ++j) {
+                        float g = ggml_get_f32_1d(ps[p]->grad, j);
+                        sum += g*g;
+                    }
+                }
+                ggml_float norm = sqrt(sum);
+                if (norm > (ggml_float) gclip) {
+                    gnorm = (float) ((ggml_float) gclip / norm);
+                }
+            }
+            const float beta1h = alpha/(1.0f - powf(beta1, opt->iter));
+            const float beta2h =  1.0f/(1.0f - powf(beta2, opt->iter));
             int64_t i = 0;
             for (int p = 0; p < np; ++p) {
-                const int64_t ne = ggml_nelements(ps[p]) ;
+                const int64_t ne = ggml_nelements(ps[p]);
                 for (int64_t j = 0; j < ne; ++j) {
                     float x = ggml_get_f32_1d(ps[p], j);
-                    float g = ggml_get_f32_1d(ps[p]->grad, j);
+                    float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
                     m[i] = m[i]*beta1 + g*(1.0f - beta1);
                     v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
-                    float mh = m[i]*alpha/(1.0f - powf(beta1, opt->iter));
-                    float vh = v[i]*1.0f /(1.0f - powf(beta2, opt->iter));
+                    float mh = m[i]*beta1h;
+                    float vh = v[i]*beta2h;
                     vh = sqrtf(vh) + eps;
                     x = x*(1.0f - decay) - mh/vh;
                     ggml_set_f32_1d(ps[p], j, x);
@@ -17902,6 +17921,7 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
                         .eps   = 1e-8f,
                         .eps_f = 1e-5f,
                         .eps_g = 1e-3f,
+                        .gclip = 0.0f,
                     },
                 };
             } break;
diff --git a/ggml.h b/ggml.h
index 531b6cb07..460976468 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1509,6 +1509,7 @@ extern "C" {
             float eps;   // epsilon for numerical stability
             float eps_f; // epsilon for convergence test
             float eps_g; // epsilon for convergence test
+            float gclip; // gradient clipping
         } adam;
 
         // LBFGS parameters
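
For reference, below is a minimal standalone sketch of the update rule that the ggml.c hunks above implement: global-norm gradient clipping followed by a bias-corrected AdamW step with decoupled weight decay. The function name and flat-array layout are illustrative only (ggml walks the parameter tensors with ggml_get_f32_1d/ggml_set_f32_1d), and iter is assumed to be 1-based as in ggml_opt_context.

// Reference sketch (not part of the patch): clipped, bias-corrected AdamW step
// over a flat parameter array. Names and layout are illustrative; ggml iterates
// over parameter tensors with ggml_get_f32_1d/ggml_set_f32_1d instead.
#include <math.h>
#include <stddef.h>

void adamw_step_sketch(
        float * x, const float * grad, float * m, float * v, size_t n, int iter,
        float alpha, float beta1, float beta2, float eps, float decay, float gclip) {
    // global-norm gradient clipping: scale all gradients by gclip/||g|| when ||g|| > gclip
    float gnorm = 1.0f;
    if (gclip > 0.0f) {
        double sum = 0.0;
        for (size_t i = 0; i < n; ++i) {
            sum += (double) grad[i] * grad[i];
        }
        const double norm = sqrt(sum);
        if (norm > (double) gclip) {
            gnorm = (float) ((double) gclip / norm);
        }
    }
    // bias-correction factors hoisted out of the loop; iter is assumed to start at 1
    const float beta1h = alpha / (1.0f - powf(beta1, iter));
    const float beta2h =  1.0f / (1.0f - powf(beta2, iter));
    for (size_t i = 0; i < n; ++i) {
        const float g = grad[i]*gnorm;               // clipped gradient
        m[i] = m[i]*beta1 + g*(1.0f - beta1);        // first moment
        v[i] = v[i]*beta2 + g*g*(1.0f - beta2);      // second moment
        const float mh = m[i]*beta1h;                // bias-corrected first moment, pre-scaled by alpha
        const float vh = sqrtf(v[i]*beta2h) + eps;   // bias-corrected denominator
        x[i] = x[i]*(1.0f - decay) - mh/vh;          // decoupled weight decay, then the Adam step
    }
}

Scaling by gclip/||g|| only when the global norm exceeds gclip preserves the gradient direction, and hoisting beta1h/beta2h out of the inner loop avoids recomputing powf per parameter, which is what the second ggml.c hunk does.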