diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 449b4e9ec..4bbf6b782 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -16,6 +16,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const float rms_norm_eps = 1e-6f;
+
 struct random_normal_distribution {
     std::mt19937 gen;
     std::normal_distribution<float> rd;
@@ -439,7 +441,7 @@ struct ggml_tensor * forward(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
             // cur = attention_norm*cur
             cur = ggml_mul(ctx0,
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
             // cur = ffn_norm*cur
             // cur shape [n_embd,N,1,1]
@@ -606,7 +608,7 @@ struct ggml_tensor * forward(
     {
 
         // inpL shape [n_embd,N,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
         // inpL = norm*inpL
         // inpL shape [n_embd,N,1,1]
@@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = attention_norm*cur
@@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = ffn_norm*cur
@@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch(
     {
 
         // inpL shape [n_embd,N*n_batch,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
@@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = attention_norm*cur
@@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = ffn_norm*cur
@@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache(
     {
 
         // inpL shape [n_embd,N*n_batch,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
@@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
 
         // norm
         {
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = attention_norm*cur
@@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
         {
             // norm
             {
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
                 assert_shape_2d(cur, n_embd, N*n_batch);
 
                 // cur = ffn_norm*cur
@@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
 
     // norm
     {
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
        assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
@@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         struct my_llama_layer & layer = model->layers[il];
         // tensors with values necessary for backward pass are in persistent buf(-1)
         // other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
-        use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm  (ctx0, cur));                       assert_shape_2d(t02, n_embd, N*n_batch);
+        use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm  (ctx0, cur, rms_norm_eps));         assert_shape_2d(t02, n_embd, N*n_batch);
         use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat    (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul       (ctx0, t02, t03));                  assert_shape_2d(t04, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat   (ctx0, layer.wq, t04));             assert_shape_2d(t05, n_embd, N*n_batch);
@@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d(ctx0, t18, n_embd, N*n_batch));    assert_shape_2d(t19, n_embd, N*n_batch);
         use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat   (ctx0, layer.wo, t19));             assert_shape_2d(t20, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add       (ctx0, t20, cur));                  assert_shape_2d(t21, n_embd, N*n_batch);
-        use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm  (ctx0, t21));                       assert_shape_2d(t22, n_embd, N*n_batch);
+        use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm  (ctx0, t21, rms_norm_eps));         assert_shape_2d(t22, n_embd, N*n_batch);
         use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat    (ctx0, layer.ffn_norm, t22));       assert_shape_2d(t23, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul       (ctx0, t23, t22));                  assert_shape_2d(t24, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat   (ctx0, layer.w3, t24));             assert_shape_2d(t25, n_ff, N*n_batch);
@@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     }
     clr_buf(0);
     use_buf(0);
-    struct ggml_tensor * t31 = expand(gf, ggml_rms_norm  (ctx0, cur));                assert_shape_2d(t31, n_embd, N*n_batch);
+    struct ggml_tensor * t31 = expand(gf, ggml_rms_norm  (ctx0, cur, rms_norm_eps));  assert_shape_2d(t31, n_embd, N*n_batch);
     struct ggml_tensor * t32 = expand(gf, ggml_repeat    (ctx0, model->norm, t31));   assert_shape_2d(t32, n_embd, N*n_batch);
     struct ggml_tensor * t33 = expand(gf, ggml_mul       (ctx0, t32, t31));           assert_shape_2d(t33, n_embd, N*n_batch);
     use_buf(-1);
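Note: every ggml_rms_norm call in this file now threads the new file-scope rms_norm_eps constant through as a trailing argument, keeping the value at the previously hard-coded 1e-6f so training behavior is unchanged. For any downstream caller the migration is the same one-liner; a minimal sketch (the tensor names x and t here are illustrative, not from the patch):

    // before this change, epsilon was fixed to 1e-6f inside ggml:
    //     struct ggml_tensor * t = ggml_rms_norm(ctx0, x);
    // after it, the caller supplies epsilon; 1e-6f reproduces the old behavior:
    static const float rms_norm_eps = 1e-6f;
    struct ggml_tensor * t = ggml_rms_norm(ctx0, x, rms_norm_eps);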
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index b8c98354d..87a166061 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    const float eps = 1e-6f;
-
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += WARP_SIZE) {
@@ -2122,10 +2120,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
 }
 
 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@@ -2876,8 +2874,11 @@ inline void ggml_cuda_op_rms_norm(
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
     // compute
-    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
 
     (void) src1;
     (void) dst;
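Note: the CUDA changes are pure plumbing; the kernel math is untouched, eps simply stops being a local constant and is passed down from ggml_cuda_op_rms_norm, which reads it out of dst->op_params. For reference, a scalar C sketch of what rms_norm_f32 computes per row (a reading aid that mirrors ggml's CPU implementation, not the shipped warp-parallel kernel):

    #include <math.h>

    // dst[i] = x[i] / sqrt(mean(x^2) + eps), over one row of ncols floats
    static void rms_norm_row_ref(const float * x, float * dst, int ncols, float eps) {
        float sum = 0.0f;
        for (int i = 0; i < ncols; i++) {
            sum += x[i]*x[i];                          // sum of squares
        }
        const float mean  = sum / ncols;               // mean square of the row
        const float scale = 1.0f/sqrtf(mean + eps);    // eps guards against division by ~0
        for (int i = 0; i < ncols; i++) {
            dst[i] = x[i]*scale;
        }
    }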
diff --git a/ggml-metal.m b/ggml-metal.m
index 1fd6e857f..c1db3d165 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -812,7 +812,8 @@ void ggml_metal_graph_compute(
                                 encoder = [command_buffer computeCommandEncoder];
                             }
 
-                            const float eps = 1e-6f;
+                            float eps;
+                            memcpy(&eps, dst->op_params, sizeof(float));
 
                             const int nth = 512;
diff --git a/ggml.c b/ggml.c
index 960b80577..11226c834 100644
--- a/ggml.c
+++ b/ggml.c
@@ -5781,6 +5781,7 @@ struct ggml_tensor * ggml_norm_inplace(
 static struct ggml_tensor * ggml_rms_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
+        float eps,
         bool inplace) {
     bool is_node = false;
 
@@ -5790,7 +5791,7 @@ static struct ggml_tensor * ggml_rms_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    // TODO: maybe store epsilon here?
+    ggml_set_op_params(result, &eps, sizeof(eps));
 
     result->op   = GGML_OP_RMS_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5801,14 +5802,16 @@ static struct ggml_tensor * ggml_rms_norm_impl(
 
 struct ggml_tensor * ggml_rms_norm(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-    return ggml_rms_norm_impl(ctx, a, false);
+        struct ggml_tensor * a,
+        float eps) {
+    return ggml_rms_norm_impl(ctx, a, eps, false);
 }
 
 struct ggml_tensor * ggml_rms_norm_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-    return ggml_rms_norm_impl(ctx, a, true);
+        struct ggml_tensor * a,
+        float eps) {
+    return ggml_rms_norm_impl(ctx, a, eps, true);
 }
 
 struct ggml_tensor * ggml_rms_norm_back(
@@ -10131,7 +10134,8 @@ static void ggml_compute_forward_rms_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const float eps = 1e-6f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
diff --git a/ggml.h b/ggml.h
index de44fba9e..1870b62e8 100644
--- a/ggml.h
+++ b/ggml.h
@@ -866,14 +866,17 @@ extern "C" {
 
     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor * a);
+            struct ggml_tensor * a,
+            float eps);
 
     GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor * a);
+            struct ggml_tensor * a,
+            float eps);
 
     // a - x
     // b - dy
+    // TODO: update with configurable eps
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
             struct ggml_context * ctx,
             struct ggml_tensor * a,
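Note: the mechanism that ties the API to the backends is the op_params scratch buffer on ggml_tensor: ggml_rms_norm_impl serializes eps into it at graph-build time, and each backend (CPU above, CUDA and Metal earlier in the patch) deserializes it at compute time. The consumer side uses memcpy rather than casting op_params to float * so it stays well-defined regardless of alignment and strict-aliasing rules. The round trip, condensed (comments are mine; both calls appear verbatim in the patch):

    // producer, ggml.c, at graph construction:
    ggml_set_op_params(result, &eps, sizeof(eps));   // raw bytes into result->op_params

    // consumer, any backend's RMS-norm compute path:
    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));     // read the same bytes back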
diff --git a/llama.cpp b/llama.cpp
index 0288f7e1f..d95651e62 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1396,11 +1396,15 @@ static bool llama_eval_internal(
     const int64_t n_vocab    = hparams.n_vocab;
     const int64_t n_embd_gqa = hparams.n_embd_gqa();
 
+    LLAMA_ASSERT(n_embd_head == hparams.n_rot);
+
     const float freq_base  = hparams.rope_freq_base;
     const float freq_scale = hparams.rope_freq_scale;
 
+    // TODO: read from hparams
+    const float rms_norm_eps = 1e-6f;
+
     const int n_gpu_layers = model.n_gpu_layers;
 
     auto & mem_per_token = lctx.mem_per_token;
@@ -1479,7 +1483,7 @@ static bool llama_eval_internal(
 
         // norm
         {
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             offload_func(cur);
             ggml_set_name(cur, "rms_norm_0");
@@ -1627,7 +1631,7 @@ static bool llama_eval_internal(
         {
             // norm
             {
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
                 offload_func(cur);
                 ggml_set_name(cur, "rms_norm_1");
@@ -1680,7 +1684,7 @@ static bool llama_eval_internal(
 
     // norm
     {
-        cur = ggml_rms_norm(ctx0, inpL);
+        cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         offload_func_nr(cur);
         ggml_set_name(cur, "rms_norm_2");
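Note: end to end, the new signature is used as in the following minimal standalone program (a sketch against the ggml API as of this change; the context size, row length, and input values are illustrative):

    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // one row of 8 activations to normalize
        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        for (int i = 0; i < 8; i++) {
            ggml_set_f32_1d(x, i, (float)(i + 1));
        }

        // epsilon is now explicit; 1e-6f matches the previously hard-coded value
        struct ggml_tensor * y = ggml_rms_norm(ctx, x, 1e-6f);

        struct ggml_cgraph gf = ggml_build_forward(y);
        ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads=*/ 1);

        for (int i = 0; i < 8; i++) {
            printf("y[%d] = %f\n", i, ggml_get_f32_1d(y, i));
        }

        ggml_free(ctx);
        return 0;
    }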