make rms_norm_eps a parameter (#2374)
* make rms_norm_eps a parameter
* add rms_norm_eps to command line
* fix baby llama, test-grad0
* use scientific notation for eps param in the help

ggml-ci
This commit is contained in:
parent  b3f138d058
commit  41c674161f

11 changed files with 89 additions and 56 deletions
llama.cpp | 20
@@ -186,6 +186,7 @@ struct llama_hparams {
     // LLaMAv2
     // TODO: load from model data hparams
     float f_ffn_mult = 1.0f;
+    float f_rms_norm_eps = 1e-6f;
 
     float rope_freq_base  = 10000.0f;
     float rope_freq_scale = 1.0f;
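For reference, f_rms_norm_eps is the small constant added under the square root in RMSNorm so the normalization stays numerically stable when activations are near zero. A minimal sketch of the math only (not the actual ggml kernel, and omitting the learned scale weights that llama.cpp applies afterwards):

    #include <cmath>
    #include <cstddef>

    // Reference RMSNorm: y[i] = x[i] / sqrt(mean(x^2) + eps).
    // eps guards the division when sum of squares is (near) zero.
    static void rms_norm_ref(const float * x, float * y, size_t n, float eps) {
        float sum_sq = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            sum_sq += x[i] * x[i];
        }
        const float scale = 1.0f / std::sqrt(sum_sq / (float) n + eps);
        for (size_t i = 0; i < n; ++i) {
            y[i] = x[i] * scale;
        }
    }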
@@ -869,6 +870,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_ctx        =*/ 512,
         /*.n_batch      =*/ 512,
         /*.n_gqa        =*/ 1,
+        /*.rms_norm_eps =*/ 1e-6f,
         /*.gpu_layers   =*/ 0,
         /*.main_gpu     =*/ 0,
         /*.tensor_split =*/ nullptr,
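With the new rms_norm_eps field in llama_context_params, callers can override the 1e-6 default before loading a model. A minimal usage sketch against the API at this revision (the model path is a placeholder):

    llama_context_params params = llama_context_default_params();
    params.rms_norm_eps = 1e-5f; // e.g. LLaMA-2 checkpoints were trained with 1e-5
    llama_model * model = llama_load_model_from_file("./models/7B/ggml-model.bin", params);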
@@ -1000,6 +1002,7 @@ static void llama_model_load_internal(
         int n_ctx,
         int n_batch,
         int n_gqa,
+        float rms_norm_eps,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
@@ -1024,6 +1027,9 @@ static void llama_model_load_internal(
 
     auto & hparams = model.hparams;
 
+    // TODO: read from file
+    hparams.f_rms_norm_eps = rms_norm_eps;
+
     {
         switch (hparams.n_layer) {
             case 26: model.type = e_model::MODEL_3B; break;
@@ -1072,6 +1078,7 @@
     fprintf(stderr, "%s: n_layer    = %u\n",  __func__, hparams.n_layer);
     fprintf(stderr, "%s: n_rot      = %u\n",  __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
     fprintf(stderr, "%s: n_gqa      = %u\n",  __func__, hparams.n_gqa());
+    fprintf(stderr, "%s: rnorm_eps  = %.1e\n", __func__, hparams.f_rms_norm_eps);
     fprintf(stderr, "%s: n_ff       = %u\n",  __func__, n_ff);
     fprintf(stderr, "%s: freq_base  = %.1f\n", __func__, hparams.rope_freq_base);
     fprintf(stderr, "%s: freq_scale = %g\n",   __func__, hparams.rope_freq_scale);
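The %.1e format is what the commit message calls "scientific notation": the default epsilon is logged as 1.0e-06 rather than 0.000001, which keeps very small values readable. A one-line illustration:

    fprintf(stderr, "rnorm_eps = %.1e\n", 1e-6f); // prints: rnorm_eps = 1.0e-06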
@@ -1330,6 +1337,7 @@ static bool llama_model_load(
         int n_ctx,
         int n_batch,
         int n_gqa,
+        float rms_norm_eps,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
@@ -1343,7 +1351,7 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
+        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
                 use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::exception & err) {
@@ -1396,10 +1404,12 @@ static bool llama_eval_internal(
 
     const int64_t n_vocab    = hparams.n_vocab;
     const int64_t n_embd_gqa = hparams.n_embd_gqa();
 
     LLAMA_ASSERT(n_embd_head == hparams.n_rot);
 
     const float freq_base    = hparams.rope_freq_base;
     const float freq_scale   = hparams.rope_freq_scale;
+    const float rms_norm_eps = hparams.f_rms_norm_eps;
+
     const int n_gpu_layers = model.n_gpu_layers;
 
@@ -1479,7 +1489,7 @@
 
         // norm
         {
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
            offload_func(cur);
            ggml_set_name(cur, "rms_norm_0");
 
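These call sites change because ggml_rms_norm now takes the epsilon as an explicit argument instead of relying on a constant baked into the kernel. The updated declaration, as it stands after this PR (the ggml.h hunk itself is not part of this excerpt, so this is shown for context):

    // ggml.h: eps is now passed per call rather than hard-coded
    GGML_API struct ggml_tensor * ggml_rms_norm(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            float                 eps);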
@@ -1627,7 +1637,7 @@
         {
             // norm
             {
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
                offload_func(cur);
                ggml_set_name(cur, "rms_norm_1");
 
@@ -1680,7 +1690,7 @@
 
     // norm
     {
-        cur = ggml_rms_norm(ctx0, inpL);
+        cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
        offload_func_nr(cur);
        ggml_set_name(cur, "rms_norm_2");
 
@@ -3084,7 +3094,7 @@ struct llama_model * llama_load_model_from_file(
 
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.n_gpu_layers,
+    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers,
                 params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
                 memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
                 params.progress_callback_user_data)) {
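The commit message also mentions exposing the epsilon on the command line; that change lives in the examples' common.cpp, which is not shown in this excerpt. Assuming the flag follows the naming used in the help text of this era, an override would look something like:

    ./main -m ./models/7B/ggml-model.bin --rms-norm-eps 1e-5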