make rms_norm_eps a parameter (#2374)
* make rms_norm_eps a parameter
* add rms_norm_eps to command line
* fix baby llama, test-grad0
* use scientific notation for eps param in the help

ggml-ci
This commit is contained in:
parent b3f138d058
commit 41c674161f
11 changed files with 89 additions and 56 deletions
|
@ -8,6 +8,8 @@
|
|||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
static const float rms_norm_eps = 1e-6f;
|
||||
|
||||
float frand() {
|
||||
return (float)rand()/(float)RAND_MAX;
|
||||
}
|
||||
|
@ -562,7 +564,7 @@ struct ggml_tensor * forward(
|
|||
// norm
|
||||
{
|
||||
// cur shape [n_embd,N,1,1]
|
||||
cur = ggml_rms_norm(ctx0, inpL);
|
||||
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
|
||||
// cur = attention_norm*cur
|
||||
cur = ggml_mul(ctx0,
|
||||
|
@ -685,7 +687,7 @@ struct ggml_tensor * forward(
|
|||
// norm
|
||||
{
|
||||
// cur shape [n_embd,N,1,1]
|
||||
cur = ggml_rms_norm(ctx0, inpFF);
|
||||
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
|
||||
|
||||
// cur = ffn_norm*cur
|
||||
// cur shape [n_embd,N,1,1]
|
||||
|
@ -729,7 +731,7 @@ struct ggml_tensor * forward(
|
|||
{
|
||||
|
||||
// inpL shape [n_embd,N,1,1]
|
||||
inpL = ggml_rms_norm(ctx0, inpL);
|
||||
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
|
||||
// inpL = norm*inpL
|
||||
// inpL shape [n_embd,N,1,1]
|
||||
|
@ -817,7 +819,7 @@ struct ggml_tensor * forward_batch(
|
|||
// norm
|
||||
{
|
||||
// cur shape [n_embd,N*n_batch,1,1]
|
||||
cur = ggml_rms_norm(ctx0, inpL);
|
||||
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
assert_shape_2d(cur, n_embd, N*n_batch);
|
||||
|
||||
// cur = attention_norm*cur
|
||||
|
@ -981,7 +983,7 @@ struct ggml_tensor * forward_batch(
|
|||
// norm
|
||||
{
|
||||
// cur shape [n_embd,N*n_batch,1,1]
|
||||
cur = ggml_rms_norm(ctx0, inpFF);
|
||||
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
|
||||
assert_shape_2d(cur, n_embd, N*n_batch);
|
||||
|
||||
// cur = ffn_norm*cur
|
||||
|
@ -1034,7 +1036,7 @@ struct ggml_tensor * forward_batch(
|
|||
{
|
||||
|
||||
// inpL shape [n_embd,N*n_batch,1,1]
|
||||
inpL = ggml_rms_norm(ctx0, inpL);
|
||||
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
assert_shape_2d(inpL, n_embd, N*n_batch);
|
||||
|
||||
// inpL = norm*inpL
|
||||
|
@ -1104,7 +1106,7 @@ struct ggml_tensor * forward_lora(
|
|||
// norm
|
||||
{
|
||||
// cur shape [n_embd,N,1,1]
|
||||
cur = ggml_rms_norm(ctx0, inpL);
|
||||
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
|
||||
// cur = attention_norm*cur
|
||||
cur = ggml_mul(ctx0,
|
||||
|
@ -1251,7 +1253,7 @@ struct ggml_tensor * forward_lora(
|
|||
// norm
|
||||
{
|
||||
// cur shape [n_embd,N,1,1]
|
||||
cur = ggml_rms_norm(ctx0, inpFF);
|
||||
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
|
||||
|
||||
// cur = ffn_norm*cur
|
||||
// cur shape [n_embd,N,1,1]
|
||||
|
@ -1295,7 +1297,7 @@ struct ggml_tensor * forward_lora(
|
|||
{
|
||||
|
||||
// inpL shape [n_embd,N,1,1]
|
||||
inpL = ggml_rms_norm(ctx0, inpL);
|
||||
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
|
||||
// inpL = norm*inpL
|
||||
// inpL shape [n_embd,N,1,1]
|
||||
|
|
|
@ -177,6 +177,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
break;
|
||||
}
|
||||
params.n_gqa = std::stoi(argv[i]);
|
||||
} else if (arg == "-eps" || arg == "--rms-norm-eps") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.rms_norm_eps = std::stof(argv[i]);
|
||||
} else if (arg == "--rope-freq-base") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -519,6 +525,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||
fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
|
||||
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
|
||||
fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
|
||||
fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
|
||||
fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
|
||||
fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
|
||||
|
@ -615,6 +622,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
|||
lparams.n_ctx = params.n_ctx;
|
||||
lparams.n_batch = params.n_batch;
|
||||
lparams.n_gqa = params.n_gqa;
|
||||
lparams.rms_norm_eps = params.rms_norm_eps;
|
||||
lparams.n_gpu_layers = params.n_gpu_layers;
|
||||
lparams.main_gpu = params.main_gpu;
|
||||
lparams.tensor_split = params.tensor_split;
|
||||
|
|
|
@ -22,18 +22,19 @@
|
|||
int32_t get_num_physical_cores();
|
||||
|
||||
struct gpt_params {
|
||||
uint32_t seed = -1; // RNG seed
|
||||
uint32_t seed = -1; // RNG seed
|
||||
int32_t n_threads = get_num_physical_cores();
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
int32_t n_ctx = 512; // context size
|
||||
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
||||
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
|
||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
int32_t n_predict = -1; // new tokens to predict
|
||||
int32_t n_ctx = 512; // context size
|
||||
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
||||
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
|
||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
float rms_norm_eps = 1e-6; // rms norm epsilon
|
||||
float rope_freq_base = 10000.0f; // RoPE base frequency
|
||||
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
|
||||
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
static const float rms_norm_eps = 1e-6f;
|
||||
|
||||
struct random_normal_distribution {
|
||||
std::mt19937 gen;
|
||||
std::normal_distribution<float> rd;
|
||||
|
@ -439,7 +441,7 @@ struct ggml_tensor * forward(
|
|||
// norm
|
||||
{
|
||||
// cur shape [n_embd,N,1,1]
|
||||
cur = ggml_rms_norm(ctx0, inpL);
|
||||
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
|
||||
// cur = attention_norm*cur
|
||||
cur = ggml_mul(ctx0,
|
||||
|
@ -562,7 +564,7 @@ struct ggml_tensor * forward(
|
|||
// norm
|
||||
{
|
||||
// cur shape [n_embd,N,1,1]
|
||||
cur = ggml_rms_norm(ctx0, inpFF);
|
||||
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
|
||||
|
||||
// cur = ffn_norm*cur
|
||||
// cur shape [n_embd,N,1,1]
|
||||
|
@ -606,7 +608,7 @@ struct ggml_tensor * forward(
|
|||
{
|
||||
|
||||
// inpL shape [n_embd,N,1,1]
|
||||
inpL = ggml_rms_norm(ctx0, inpL);
|
||||
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
|
||||
// inpL = norm*inpL
|
||||
// inpL shape [n_embd,N,1,1]
|
||||
|
@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch(
|
|||
// norm
|
||||
{
|
||||
// cur shape [n_embd,N*n_batch,1,1]
|
||||
cur = ggml_rms_norm(ctx0, inpL);
|
||||
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
assert_shape_2d(cur, n_embd, N*n_batch);
|
||||
|
||||
// cur = attention_norm*cur
|
||||
|
@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch(
|
|||
// norm
|
||||
{
|
||||
// cur shape [n_embd,N*n_batch,1,1]
|
||||
cur = ggml_rms_norm(ctx0, inpFF);
|
||||
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
|
||||
assert_shape_2d(cur, n_embd, N*n_batch);
|
||||
|
||||
// cur = ffn_norm*cur
|
||||
|
@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch(
|
|||
{
|
||||
|
||||
// inpL shape [n_embd,N*n_batch,1,1]
|
||||
inpL = ggml_rms_norm(ctx0, inpL);
|
||||
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
assert_shape_2d(inpL, n_embd, N*n_batch);
|
||||
|
||||
// inpL = norm*inpL
|
||||
|
@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache(
|
|||
// norm
|
||||
{
|
||||
// cur shape [n_embd,N*n_batch,1,1]
|
||||
cur = ggml_rms_norm(ctx0, inpL);
|
||||
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
assert_shape_2d(cur, n_embd, N*n_batch);
|
||||
|
||||
// cur = attention_norm*cur
|
||||
|
@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache(
|
|||
// norm
|
||||
{
|
||||
// cur shape [n_embd,N*n_batch,1,1]
|
||||
cur = ggml_rms_norm(ctx0, inpFF);
|
||||
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
|
||||
assert_shape_2d(cur, n_embd, N*n_batch);
|
||||
|
||||
// cur = ffn_norm*cur
|
||||
|
@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache(
|
|||
{
|
||||
|
||||
// inpL shape [n_embd,N*n_batch,1,1]
|
||||
inpL = ggml_rms_norm(ctx0, inpL);
|
||||
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
assert_shape_2d(inpL, n_embd, N*n_batch);
|
||||
|
||||
// inpL = norm*inpL
|
||||
|
@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
|
|||
|
||||
// norm
|
||||
{
|
||||
cur = ggml_rms_norm(ctx0, inpL);
|
||||
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
assert_shape_2d(cur, n_embd, N*n_batch);
|
||||
|
||||
// cur = attention_norm*cur
|
||||
|
@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
|
|||
{
|
||||
// norm
|
||||
{
|
||||
cur = ggml_rms_norm(ctx0, inpFF);
|
||||
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
|
||||
assert_shape_2d(cur, n_embd, N*n_batch);
|
||||
|
||||
// cur = ffn_norm*cur
|
||||
|
@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
|
|||
// norm
|
||||
{
|
||||
|
||||
inpL = ggml_rms_norm(ctx0, inpL);
|
||||
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
|
||||
assert_shape_2d(inpL, n_embd, N*n_batch);
|
||||
|
||||
// inpL = norm*inpL
|
||||
|
@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
|
|||
struct my_llama_layer & layer = model->layers[il];
|
||||
// tensors with values necessary for backward pass are in persistent buf(-1)
|
||||
// other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
|
||||
use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch);
|
||||
use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
|
||||
use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
|
||||
use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
|
||||
use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
|
||||
|
@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
|
|||
use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
|
||||
use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
|
||||
use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
|
||||
use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch);
|
||||
use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
|
||||
use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
|
||||
use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
|
||||
use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
|
||||
|
@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
|
|||
}
|
||||
clr_buf(0);
|
||||
use_buf(0);
|
||||
struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
|
||||
struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
|
||||
use_buf(-1);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue