make rms_norm_eps a parameter

This commit is contained in:
slaren 2023-07-24 16:18:40 +02:00
parent 5b2b2dc6ae
commit 9fe47c747f
6 changed files with 48 additions and 33 deletions

View file

@ -16,6 +16,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
static const float rms_norm_eps = 1e-6f;
struct random_normal_distribution { struct random_normal_distribution {
std::mt19937 gen; std::mt19937 gen;
std::normal_distribution<float> rd; std::normal_distribution<float> rd;
@ -439,7 +441,7 @@ struct ggml_tensor * forward(
// norm // norm
{ {
// cur shape [n_embd,N,1,1] // cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// cur = attention_norm*cur // cur = attention_norm*cur
cur = ggml_mul(ctx0, cur = ggml_mul(ctx0,
@ -562,7 +564,7 @@ struct ggml_tensor * forward(
// norm // norm
{ {
// cur shape [n_embd,N,1,1] // cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
// cur = ffn_norm*cur // cur = ffn_norm*cur
// cur shape [n_embd,N,1,1] // cur shape [n_embd,N,1,1]
@ -606,7 +608,7 @@ struct ggml_tensor * forward(
{ {
// inpL shape [n_embd,N,1,1] // inpL shape [n_embd,N,1,1]
inpL = ggml_rms_norm(ctx0, inpL); inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// inpL = norm*inpL // inpL = norm*inpL
// inpL shape [n_embd,N,1,1] // inpL shape [n_embd,N,1,1]
@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch(
// norm // norm
{ {
// cur shape [n_embd,N*n_batch,1,1] // cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur // cur = attention_norm*cur
@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch(
// norm // norm
{ {
// cur shape [n_embd,N*n_batch,1,1] // cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur // cur = ffn_norm*cur
@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch(
{ {
// inpL shape [n_embd,N*n_batch,1,1] // inpL shape [n_embd,N*n_batch,1,1]
inpL = ggml_rms_norm(ctx0, inpL); inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch); assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL // inpL = norm*inpL
@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache(
// norm // norm
{ {
// cur shape [n_embd,N*n_batch,1,1] // cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur // cur = attention_norm*cur
@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache(
// norm // norm
{ {
// cur shape [n_embd,N*n_batch,1,1] // cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur // cur = ffn_norm*cur
@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache(
{ {
// inpL shape [n_embd,N*n_batch,1,1] // inpL shape [n_embd,N*n_batch,1,1]
inpL = ggml_rms_norm(ctx0, inpL); inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch); assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL // inpL = norm*inpL
@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur // cur = attention_norm*cur
@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
{ {
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur // cur = ffn_norm*cur
@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
// norm // norm
{ {
inpL = ggml_rms_norm(ctx0, inpL); inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch); assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL // inpL = norm*inpL
@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
struct my_llama_layer & layer = model->layers[il]; struct my_llama_layer & layer = model->layers[il];
// tensors with values necessary for backward pass are in persistent buf(-1) // tensors with values necessary for backward pass are in persistent buf(-1)
// other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed. // other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch); use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch); use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch); use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch); use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
} }
clr_buf(0); clr_buf(0);
use_buf(0); use_buf(0);
struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch); struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch);
struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch); struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch); struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
use_buf(-1); use_buf(-1);

View file

@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
} }
} }
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) { static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x; const int tid = threadIdx.x;
const float eps = 1e-6f;
float tmp = 0.0f; // partial sum for thread in warp float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += WARP_SIZE) { for (int col = tid; col < ncols; col += WARP_SIZE) {
@ -2122,10 +2120,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols); norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
} }
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0); GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols); rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} }
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) { static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@ -2876,8 +2874,11 @@ inline void ggml_cuda_op_rms_norm(
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low; const int64_t i01_diff = i01_high - i01_low;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
// compute // compute
rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
(void) src1; (void) src1;
(void) dst; (void) dst;

View file

@ -812,7 +812,8 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder]; encoder = [command_buffer computeCommandEncoder];
} }
const float eps = 1e-6f; float eps;
memcpy(&eps, dst->op_params, sizeof(float));
const int nth = 512; const int nth = 512;

16
ggml.c
View file

@ -5781,6 +5781,7 @@ struct ggml_tensor * ggml_norm_inplace(
static struct ggml_tensor * ggml_rms_norm_impl( static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
float eps,
bool inplace) { bool inplace) {
bool is_node = false; bool is_node = false;
@ -5790,7 +5791,7 @@ static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
// TODO: maybe store epsilon here? ggml_set_op_params(result, &eps, sizeof(eps));
result->op = GGML_OP_RMS_NORM; result->op = GGML_OP_RMS_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -5801,14 +5802,16 @@ static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_tensor * ggml_rms_norm( struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a) { struct ggml_tensor * a,
return ggml_rms_norm_impl(ctx, a, false); float eps) {
return ggml_rms_norm_impl(ctx, a, eps, false);
} }
struct ggml_tensor * ggml_rms_norm_inplace( struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a) { struct ggml_tensor * a,
return ggml_rms_norm_impl(ctx, a, true); float eps) {
return ggml_rms_norm_impl(ctx, a, eps, true);
} }
struct ggml_tensor * ggml_rms_norm_back( struct ggml_tensor * ggml_rms_norm_back(
@ -10131,7 +10134,8 @@ static void ggml_compute_forward_rms_norm_f32(
GGML_TENSOR_UNARY_OP_LOCALS; GGML_TENSOR_UNARY_OP_LOCALS;
const float eps = 1e-6f; // TODO: make this a parameter float eps;
memcpy(&eps, dst->op_params, sizeof(float));
// TODO: optimize // TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i03 = 0; i03 < ne03; i03++) {

7
ggml.h
View file

@ -866,14 +866,17 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_rms_norm( GGML_API struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a,
float eps);
GGML_API struct ggml_tensor * ggml_rms_norm_inplace( GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a,
float eps);
// a - x // a - x
// b - dy // b - dy
// TODO: update with configurable eps
GGML_API struct ggml_tensor * ggml_rms_norm_back( GGML_API struct ggml_tensor * ggml_rms_norm_back(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,

View file

@ -1396,11 +1396,15 @@ static bool llama_eval_internal(
const int64_t n_vocab = hparams.n_vocab; const int64_t n_vocab = hparams.n_vocab;
const int64_t n_embd_gqa = hparams.n_embd_gqa(); const int64_t n_embd_gqa = hparams.n_embd_gqa();
LLAMA_ASSERT(n_embd_head == hparams.n_rot); LLAMA_ASSERT(n_embd_head == hparams.n_rot);
const float freq_base = hparams.rope_freq_base; const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale; const float freq_scale = hparams.rope_freq_scale;
// TODO: read from hparams
const float rms_norm_eps = 1e-6f;
const int n_gpu_layers = model.n_gpu_layers; const int n_gpu_layers = model.n_gpu_layers;
auto & mem_per_token = lctx.mem_per_token; auto & mem_per_token = lctx.mem_per_token;
@ -1479,7 +1483,7 @@ static bool llama_eval_internal(
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
offload_func(cur); offload_func(cur);
ggml_set_name(cur, "rms_norm_0"); ggml_set_name(cur, "rms_norm_0");
@ -1627,7 +1631,7 @@ static bool llama_eval_internal(
{ {
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
offload_func(cur); offload_func(cur);
ggml_set_name(cur, "rms_norm_1"); ggml_set_name(cur, "rms_norm_1");
@ -1680,7 +1684,7 @@ static bool llama_eval_internal(
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
offload_func_nr(cur); offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_2"); ggml_set_name(cur, "rms_norm_2");