diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 88439cfe3..8f35fe2c9 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -18,8 +18,6 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -static const float rms_norm_eps = 1e-5f; - struct random_normal_distribution { std::mt19937 gen; std::normal_distribution rd; @@ -502,6 +500,7 @@ struct ggml_tensor * forward( const int n_layer = hparams.n_layer; const int n_head = hparams.n_head; const int n_rot = hparams.n_rot; + const float rms_norm_eps = hparams.f_norm_rms_eps; struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); memcpy(tokens->data, tokens_input->data, N*ggml_element_size(tokens)); @@ -959,6 +958,9 @@ struct ggml_tensor * llama_build_train_graphs( const int n_head = hparams.n_head; const int n_rot = hparams.n_rot; const int n_ff = hparams.n_ff; + const float f_norm_rms_eps = hparams.f_norm_rms_eps; + const float rope_freq_base = hparams.rope_freq_base; + const float rope_freq_scale = hparams.rope_freq_scale; const int rope_mode = 0; auto set_name = [](struct ggml_tensor * t, const char * n) { @@ -968,6 +970,14 @@ struct ggml_tensor * llama_build_train_graphs( } }; + // rope has so much parameters that we make a custom function for it + auto rope = [ctx, n_past, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale] + (struct ggml_tensor * t) -> struct ggml_tensor * { + return ggml_rope_custom(ctx, + t, n_past, n_rot, rope_mode, n_ctx, + rope_freq_base, rope_freq_scale); + }; + set_name(tokens_input, "tokens_input"); set_name(targets, "targets"); @@ -990,15 +1000,15 @@ struct ggml_tensor * llama_build_train_graphs( for (int il = 0; il < n_layer; ++il) { struct my_llama_layer & layer = model->layers[il]; - struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, rms_norm_eps); set_name(t02, "t02"); assert_shape_2d(t02, n_embd, N*n_batch); + struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, f_norm_rms_eps); set_name(t02, "t02"); assert_shape_2d(t02, n_embd, N*n_batch); struct ggml_tensor * t03 = ggml_repeat (ctx, layer.attention_norm, t02); set_name(t03, "t03"); assert_shape_2d(t03, n_embd, N*n_batch); struct ggml_tensor * t04 = ggml_mul (ctx, t03, t02); set_name(t04, "t04"); assert_shape_2d(t04, n_embd, N*n_batch); struct ggml_tensor * t05 = ggml_mul_mat (ctx, layer.wq, t04); set_name(t05, "t05"); assert_shape_2d(t05, n_embd, N*n_batch); struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd/n_head, n_head, N, n_batch); set_name(t06, "t06"); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch); - struct ggml_tensor * t07 = ggml_rope_inplace (ctx, t06, n_past, n_rot, rope_mode, n_ctx); set_name(t07, "t07"); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch); + struct ggml_tensor * t07 = rope (t06); set_name(t07, "t07"); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch); struct ggml_tensor * t08 = ggml_mul_mat (ctx, layer.wk, t04); set_name(t08, "t08"); assert_shape_2d(t08, n_embd, N*n_batch); struct ggml_tensor * t09 = ggml_reshape_4d (ctx, t08, n_embd/n_head, n_head, N, n_batch); set_name(t09, "t09"); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch); - struct ggml_tensor * t10 = ggml_rope_inplace (ctx, t09, n_past, n_rot, rope_mode, n_ctx); set_name(t10, "t10"); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch); + struct ggml_tensor * t10 = rope (t09); set_name(t10, "t10"); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch); struct ggml_tensor * t11 = ggml_mul_mat (ctx, t04, layer.wv); set_name(t11, "t11"); assert_shape_2d(t11, N*n_batch, n_embd); struct ggml_tensor * t12 = ggml_reshape_4d (ctx, t11, N, n_batch, n_embd/n_head, n_head); set_name(t12, "t12"); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head); struct ggml_tensor * t13 = ggml_permute (ctx, t07, 0, 2, 1, 3); set_name(t13, "t13"); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch); @@ -1019,7 +1029,7 @@ struct ggml_tensor * llama_build_train_graphs( struct ggml_tensor * t19 = ggml_reshape_2d (ctx, t18, n_embd, N*n_batch); set_name(t19, "t19"); assert_shape_2d(t19, n_embd, N*n_batch); struct ggml_tensor * t20 = ggml_mul_mat (ctx, layer.wo, t19); set_name(t20, "t20"); assert_shape_2d(t20, n_embd, N*n_batch); struct ggml_tensor * t21 = ggml_add (ctx, t20, cur); set_name(t21, "t21"); assert_shape_2d(t21, n_embd, N*n_batch); - struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, rms_norm_eps); set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch); + struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, f_norm_rms_eps); set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch); struct ggml_tensor * t23 = ggml_repeat (ctx, layer.ffn_norm, t22); set_name(t23, "t23"); assert_shape_2d(t23, n_embd, N*n_batch); struct ggml_tensor * t24 = ggml_mul (ctx, t23, t22); set_name(t24, "t24"); assert_shape_2d(t24, n_embd, N*n_batch); struct ggml_tensor * t25 = ggml_mul_mat (ctx, layer.w3, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch); @@ -1031,7 +1041,7 @@ struct ggml_tensor * llama_build_train_graphs( cur = t30; checkpoints.push_back(cur); } - struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, rms_norm_eps); set_name(t31, "t31"); assert_shape_2d(t31, n_embd, N*n_batch); + struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, f_norm_rms_eps); set_name(t31, "t31"); assert_shape_2d(t31, n_embd, N*n_batch); struct ggml_tensor * t32 = ggml_repeat (ctx, model->norm, t31); set_name(t32, "t32"); assert_shape_2d(t32, n_embd, N*n_batch); struct ggml_tensor * t33 = ggml_mul (ctx, t32, t31); set_name(t33, "t33"); assert_shape_2d(t33, n_embd, N*n_batch); struct ggml_tensor * t34 = ggml_mul_mat (ctx, model->output, t33); set_name(t34, "t34"); assert_shape_2d(t34, n_vocab, N*n_batch); @@ -1960,6 +1970,10 @@ struct train_params { int n_examples; int n_predict; + float f_norm_rms_eps; + float rope_freq_base; + float rope_freq_scale; + int print_info_interval; int print_details_interval; @@ -2017,6 +2031,10 @@ struct train_params get_default_train_params() { params.n_examples = 1; params.n_predict = 1024; + params.f_norm_rms_eps = 1e-5; + params.rope_freq_base = 10000.0f; + params.rope_freq_scale = 1.0f; + params.print_info_interval = 1; params.print_details_interval = 2; @@ -2070,6 +2088,9 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p fprintf(stderr, " --ff N Feedforward size used for new models. (default %d)\n", params->n_ff); fprintf(stderr, " --head N Number of heads for new models (default %d)\n", params->n_head); fprintf(stderr, " --layer N Number of layers for new models (default %d)\n", params->n_layer); + fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps); + fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base); + fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale); fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads); fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch); fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples); @@ -2188,6 +2209,24 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { break; } params->n_layer = std::stoi(argv[i]); + } else if (arg == "--norm-rms-eps") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->f_norm_rms_eps = std::stof(argv[i]); + } else if (arg == "--rope-freq-base") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->rope_freq_base = std::stof(argv[i]); + } else if (arg == "--rope-freq-scale") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->rope_freq_scale = std::stof(argv[i]); } else if (arg == "-t" || arg == "--threads") { if (++i >= argc) { invalid_param = true; @@ -2476,6 +2515,9 @@ int main(int argc, char ** argv) { model.hparams.n_ff = params.n_ff; // llama.cpp requires n_rot to be exactly n_embd / n_head model.hparams.n_rot = model.hparams.n_embd / model.hparams.n_head; + model.hparams.f_norm_rms_eps = params.f_norm_rms_eps; + model.hparams.rope_freq_base = params.rope_freq_base; + model.hparams.rope_freq_scale = params.rope_freq_scale; print_params(&model.hparams);