handle rms_norm and rope parameters the same as in train-text-from-scratch

xaedes 2023-08-29 01:09:35 +02:00
parent 007280c82f
commit 1faee64db9

@@ -176,8 +176,6 @@ struct my_llama_hparams {
uint32_t n_layer = 32;
uint32_t n_rot = 64;
float f_rms_norm_eps = 1e-5f;
bool operator!=(const my_llama_hparams& other) const {
return memcmp(this, &other, sizeof(other));
}
@@ -229,6 +227,12 @@ struct my_llama_lora_hparams {
uint32_t n_rank_norm = 1;
uint32_t n_rank_output = 4;
// float f_norm_eps = 1e-5f; // falcon
float f_norm_rms_eps = 1e-5f; // llama
float rope_freq_base = 10000.0f;
float rope_freq_scale = 1.0f;
bool operator!=(const my_llama_lora_hparams& other) const {
return memcmp(this, &other, sizeof(other));
}
@@ -769,8 +773,9 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
const int n_head = hparams.n_head;
const int n_rot = hparams.n_rot;
const int n_ff = hparams.n_ff;
const float rms_norm_eps = hparams.f_rms_norm_eps;
const int rope_mode = 0;
const float rms_norm_eps = lora->hparams.f_norm_rms_eps;
const float rope_freq_base = lora->hparams.rope_freq_base;
const float rope_freq_scale = lora->hparams.rope_freq_scale;
GGML_ASSERT(n_layer == lora->layers.size());
@@ -781,6 +786,18 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
}
};
// rope has so many parameters that we make a custom function for it
auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
(struct ggml_tensor * t) -> struct ggml_tensor * {
// not capturing these, to silence warnings
const int n_past = 0;
const int rope_mode = 0;
return ggml_rope_custom(ctx,
t, n_past, n_rot, rope_mode, n_ctx,
rope_freq_base, rope_freq_scale);
};
set_name(tokens_input, "tokens_input");
set_name(targets, "targets");
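
For reference, a rough sketch (not part of this commit, approximating ggml's RoPE kernel) of how the two new parameters enter the rotation computed by ggml_rope_custom:

// rotation angle applied to dimension pair i of a token at position pos (sketch)
// theta(pos, i) = freq_scale * pos * powf(freq_base, -2.0f*i/n_rot)
// a larger freq_base spreads the wavelengths across dimension pairs, while
// freq_scale < 1.0f compresses positions (linear RoPE scaling for longer contexts)
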
@@ -834,10 +851,10 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
struct ggml_tensor * t04 = ggml_mul (ctx, t03, t02); set_name(t04, "t04"); assert_shape_2d(t04, n_embd, N*n_batch);
struct ggml_tensor * t05 = ggml_mul_mat (ctx, wq, t04); set_name(t05, "t05"); assert_shape_2d(t05, n_embd, N*n_batch);
struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd/n_head, n_head, N, n_batch); set_name(t06, "t06"); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t07 = ggml_rope_inplace (ctx, t06, n_past, n_rot, rope_mode, n_ctx); set_name(t07, "t07"); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t07 = rope (t06); set_name(t07, "t07"); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t08 = ggml_mul_mat (ctx, wk, t04); set_name(t08, "t08"); assert_shape_2d(t08, n_embd, N*n_batch);
struct ggml_tensor * t09 = ggml_reshape_4d (ctx, t08, n_embd/n_head, n_head, N, n_batch); set_name(t09, "t09"); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t10 = ggml_rope_inplace (ctx, t09, n_past, n_rot, rope_mode, n_ctx); set_name(t10, "t10"); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t10 = rope (t09); set_name(t10, "t10"); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
struct ggml_tensor * t11;
if (ggml_is_quantized(wv->type)) {
@@ -1631,6 +1648,10 @@ struct train_params {
int n_batch;
int n_examples;
float f_norm_rms_eps;
float rope_freq_base;
float rope_freq_scale;
int32_t lora_r;
int32_t lora_alpha;
@@ -1701,6 +1722,10 @@ struct train_params get_default_train_params() {
params.n_batch = 8;
params.n_examples = 1;
params.f_norm_rms_eps = 1e-5f;
params.rope_freq_base = 10000.0f;
params.rope_freq_scale = 1.0f;
params.lora_alpha = 4;
params.lora_r = 4;
@@ -1771,6 +1796,9 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads);
fprintf(stderr, " -b N, --batch N Parallel batch size (default %d)\n", params->n_batch);
fprintf(stderr, " -n N, --examples N Number of examples to train (default %d)\n", params->n_examples);
fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
fprintf(stderr, " --lora-alpha N LORA alpha : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_alpha);
fprintf(stderr, " --lora-r N LORA r : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_r);
fprintf(stderr, " --rank-att-norm N LORA rank for attention norm tensor (default %d)\n", params->n_rank_attention_norm);
@@ -1910,6 +1938,24 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
break;
}
params->n_examples = std::stoi(argv[i]);
} else if (arg == "--norm-rms-eps") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->f_norm_rms_eps = std::stof(argv[i]);
} else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->rope_freq_base = std::stof(argv[i]);
} else if (arg == "--rope-freq-scale") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--lora-alpha") {
if (++i >= argc) {
invalid_param = true;
@@ -2290,6 +2336,9 @@ int main(int argc, char ** argv) {
init_model(lmodel, &model, params.n_ctx);
struct my_llama_lora lora;
lora.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
lora.hparams.rope_freq_base = params.rope_freq_base;
lora.hparams.rope_freq_scale = params.rope_freq_scale;
lora.hparams.lora_r = params.lora_r;
lora.hparams.lora_alpha = params.lora_alpha;
lora.hparams.n_rank_attention_norm = params.n_rank_attention_norm;