handle rms_norm and rope parameters the same as in train-text-from-scratch
parent 007280c82f
commit 1faee64db9
1 changed file with 55 additions and 6 deletions
@@ -176,8 +176,6 @@ struct my_llama_hparams {
     uint32_t n_layer = 32;
     uint32_t n_rot = 64;
 
-    float f_rms_norm_eps = 1e-5f;
-
     bool operator!=(const my_llama_hparams& other) const {
         return memcmp(this, &other, sizeof(other));
     }
@@ -229,6 +227,12 @@ struct my_llama_lora_hparams {
     uint32_t n_rank_norm = 1;
     uint32_t n_rank_output = 4;
 
+    // float f_norm_eps = 1e-5f; // falcon
+    float f_norm_rms_eps = 1e-5f; // llama
+
+    float rope_freq_base = 10000.0f;
+    float rope_freq_scale = 1.0f;
+
     bool operator!=(const my_llama_lora_hparams& other) const {
         return memcmp(this, &other, sizeof(other));
     }
@@ -769,8 +773,9 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
     const int n_head = hparams.n_head;
     const int n_rot  = hparams.n_rot;
     const int n_ff   = hparams.n_ff;
-    const float rms_norm_eps = hparams.f_rms_norm_eps;
-    const int rope_mode = 0;
+    const float rms_norm_eps    = lora->hparams.f_norm_rms_eps;
+    const float rope_freq_base  = lora->hparams.rope_freq_base;
+    const float rope_freq_scale = lora->hparams.rope_freq_scale;
 
     GGML_ASSERT(n_layer == lora->layers.size());
 
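Note: the epsilon read here feeds the RMS-norm operations built later in the graph. A minimal sketch of such a call, assuming the ggml_rms_norm signature that takes an explicit epsilon (the tensor names are illustrative, not part of this hunk):

    // normalize the layer input with the configurable epsilon
    struct ggml_tensor * t02 = ggml_rms_norm(ctx, t01, rms_norm_eps);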
@@ -781,6 +786,18 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
         }
     };
 
+    // rope has so many parameters that we make a custom function for it
+    auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
+                (struct ggml_tensor * t) -> struct ggml_tensor * {
+        // not capturing these, to silence warnings
+        const int n_past = 0;
+        const int rope_mode = 0;
+
+        return ggml_rope_custom(ctx,
+            t, n_past, n_rot, rope_mode, n_ctx,
+            rope_freq_base, rope_freq_scale);
+    };
+
     set_name(tokens_input, "tokens_input");
     set_name(targets, "targets");
 
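With the default parameters the new helper reproduces the old call. A sketch of the equivalence, assuming ggml's built-in defaults of 10000.0f and 1.0f (and setting aside the in-place vs. copying distinction):

    // before: frequency parameters hard-coded inside ggml
    struct ggml_tensor * a = ggml_rope_inplace(ctx, t, /*n_past*/ 0, n_rot, /*mode*/ 0, n_ctx);
    // after: the same rotation, with base and scale made explicit
    struct ggml_tensor * b = ggml_rope_custom(ctx, t, /*n_past*/ 0, n_rot, /*mode*/ 0, n_ctx, 10000.0f, 1.0f);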
@@ -834,10 +851,10 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
     struct ggml_tensor * t04 = ggml_mul (ctx, t03, t02); set_name(t04, "t04"); assert_shape_2d(t04, n_embd, N*n_batch);
     struct ggml_tensor * t05 = ggml_mul_mat (ctx, wq, t04); set_name(t05, "t05"); assert_shape_2d(t05, n_embd, N*n_batch);
     struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd/n_head, n_head, N, n_batch); set_name(t06, "t06"); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
-    struct ggml_tensor * t07 = ggml_rope_inplace (ctx, t06, n_past, n_rot, rope_mode, n_ctx); set_name(t07, "t07"); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
+    struct ggml_tensor * t07 = rope (t06); set_name(t07, "t07"); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
     struct ggml_tensor * t08 = ggml_mul_mat (ctx, wk, t04); set_name(t08, "t08"); assert_shape_2d(t08, n_embd, N*n_batch);
     struct ggml_tensor * t09 = ggml_reshape_4d (ctx, t08, n_embd/n_head, n_head, N, n_batch); set_name(t09, "t09"); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
-    struct ggml_tensor * t10 = ggml_rope_inplace (ctx, t09, n_past, n_rot, rope_mode, n_ctx); set_name(t10, "t10"); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
+    struct ggml_tensor * t10 = rope (t09); set_name(t10, "t10"); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
 
     struct ggml_tensor * t11;
     if (ggml_is_quantized(wv->type)) {
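For reference, both parameters enter the rotation angles in the usual RoPE formulation: pair i of the n_rot rotated dimensions at token position p is rotated by theta_i = freq_scale * p * freq_base^(-2i/n_rot), so rope_freq_scale applies linear position scaling while rope_freq_base stretches or compresses the frequency spectrum.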
@@ -1631,6 +1648,10 @@ struct train_params {
     int n_batch;
     int n_examples;
 
+    float f_norm_rms_eps;
+    float rope_freq_base;
+    float rope_freq_scale;
+
     int32_t lora_r;
     int32_t lora_alpha;
 
@@ -1701,6 +1722,10 @@ struct train_params get_default_train_params() {
     params.n_batch = 8;
     params.n_examples = 1;
 
+    params.f_norm_rms_eps = 1e-5f;
+    params.rope_freq_base = 10000.0f;
+    params.rope_freq_scale = 1.0f;
+
     params.lora_alpha = 4;
     params.lora_r = 4;
 
@@ -1771,6 +1796,9 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, "  -t N, --threads N      Number of threads (default %d)\n", params->n_threads);
     fprintf(stderr, "  -b N, --batch N        Parallel batch size (default %d)\n", params->n_batch);
     fprintf(stderr, "  -n N, --examples N     Number of examples to train (default %d)\n", params->n_examples);
+    fprintf(stderr, "  --norm-rms-eps F       RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
+    fprintf(stderr, "  --rope-freq-base F     Frequency base for ROPE (default %f)\n", params->rope_freq_base);
+    fprintf(stderr, "  --rope-freq-scale F    Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
     fprintf(stderr, "  --lora-alpha N         LORA alpha : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_alpha);
     fprintf(stderr, "  --lora-r N             LORA r : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_r);
     fprintf(stderr, "  --rank-att-norm N      LORA rank for attention norm tensor (default %d)\n", params->n_rank_attention_norm);
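For illustration, the new options are passed like any other flag; the binary name and the omitted required arguments below are placeholders, not part of this commit:

    # hypothetical invocation; other required flags omitted
    ./finetune ... --norm-rms-eps 1e-5 --rope-freq-base 10000 --rope-freq-scale 1.0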
@@ -1910,6 +1938,24 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->n_examples = std::stoi(argv[i]);
+        } else if (arg == "--norm-rms-eps") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->f_norm_rms_eps = std::stof(argv[i]);
+        } else if (arg == "--rope-freq-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->rope_freq_base = std::stof(argv[i]);
+        } else if (arg == "--rope-freq-scale") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->rope_freq_scale = std::stof(argv[i]);
         } else if (arg == "--lora-alpha") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -2290,6 +2336,9 @@ int main(int argc, char ** argv) {
     init_model(lmodel, &model, params.n_ctx);
 
     struct my_llama_lora lora;
+    lora.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
+    lora.hparams.rope_freq_base = params.rope_freq_base;
+    lora.hparams.rope_freq_scale = params.rope_freq_scale;
     lora.hparams.lora_r = params.lora_r;
     lora.hparams.lora_alpha = params.lora_alpha;
     lora.hparams.n_rank_attention_norm = params.n_rank_attention_norm;