make rms_norm_eps a parameter
parent 5b2b2dc6ae
commit 9fe47c747f
6 changed files with 48 additions and 33 deletions
@@ -16,6 +16,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const float rms_norm_eps = 1e-6f;
+
 struct random_normal_distribution {
     std::mt19937 gen;
     std::normal_distribution<float> rd;
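For context (not part of the diff): rms_norm_eps is the small constant added inside the RMS-norm denominator. A minimal reference sketch in C, assuming the standard formulation (scale = 1/sqrt(mean(x^2) + eps)) that the kernels touched by this commit implement:

    // reference sketch only, not code from this commit
    #include <math.h>
    #include <stddef.h>

    static void rms_norm_row(const float * x, float * y, size_t n, float eps) {
        float sum = 0.0f;
        for (size_t i = 0; i < n; ++i) {
            sum += x[i] * x[i];                                   // sum of squares over the row
        }
        const float scale = 1.0f / sqrtf(sum / (float) n + eps);  // eps keeps the sqrt away from zero
        for (size_t i = 0; i < n; ++i) {
            y[i] = x[i] * scale;
        }
    }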
@@ -439,7 +441,7 @@ struct ggml_tensor * forward(
     // norm
     {
         // cur shape [n_embd,N,1,1]
-        cur = ggml_rms_norm(ctx0, inpL);
+        cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
         // cur = attention_norm*cur
         cur = ggml_mul(ctx0,
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
     // norm
     {
         // cur shape [n_embd,N,1,1]
-        cur = ggml_rms_norm(ctx0, inpFF);
+        cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
         // cur = ffn_norm*cur
         // cur shape [n_embd,N,1,1]
@@ -606,7 +608,7 @@ struct ggml_tensor * forward(
     {
 
         // inpL shape [n_embd,N,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
         // inpL = norm*inpL
         // inpL shape [n_embd,N,1,1]
@@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch(
     // norm
     {
         // cur shape [n_embd,N*n_batch,1,1]
-        cur = ggml_rms_norm(ctx0, inpL);
+        cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(cur, n_embd, N*n_batch);
 
         // cur = attention_norm*cur
@@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch(
     // norm
     {
         // cur shape [n_embd,N*n_batch,1,1]
-        cur = ggml_rms_norm(ctx0, inpFF);
+        cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
         assert_shape_2d(cur, n_embd, N*n_batch);
 
         // cur = ffn_norm*cur
@@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch(
     {
 
         // inpL shape [n_embd,N*n_batch,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
@@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache(
     // norm
     {
         // cur shape [n_embd,N*n_batch,1,1]
-        cur = ggml_rms_norm(ctx0, inpL);
+        cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(cur, n_embd, N*n_batch);
 
         // cur = attention_norm*cur
@@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache(
     // norm
     {
         // cur shape [n_embd,N*n_batch,1,1]
-        cur = ggml_rms_norm(ctx0, inpFF);
+        cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
         assert_shape_2d(cur, n_embd, N*n_batch);
 
         // cur = ffn_norm*cur
@@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache(
     {
 
         // inpL shape [n_embd,N*n_batch,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
@@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
 
     // norm
     {
-        cur = ggml_rms_norm(ctx0, inpL);
+        cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(cur, n_embd, N*n_batch);
 
         // cur = attention_norm*cur
@@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
     {
         // norm
         {
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = ffn_norm*cur
@@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
     // norm
     {
 
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
@@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     struct my_llama_layer & layer = model->layers[il];
     // tensors with values necessary for backward pass are in persistent buf(-1)
     // other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
-    use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch);
+    use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
     use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
     use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
     use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
@@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
     use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
     use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
-    use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch);
+    use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
     use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
     use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
     use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
@@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     }
     clr_buf(0);
     use_buf(0);
-    struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch);
+    struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch);
     struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
     struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
     use_buf(-1);
ggml-cuda.cu (13 changed lines)
@@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    const float eps = 1e-6f;
-
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += WARP_SIZE) {
@@ -2122,10 +2120,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
     norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
 }
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
 }
 
 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@@ -2876,8 +2874,11 @@ inline void ggml_cuda_op_rms_norm(
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;
 
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
     // compute
-    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
 
     (void) src1;
     (void) dst;
@@ -812,7 +812,8 @@ void ggml_metal_graph_compute(
         encoder = [command_buffer computeCommandEncoder];
     }
 
-    const float eps = 1e-6f;
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     const int nth = 512;
 
ggml.c (16 changed lines)
@@ -5781,6 +5781,7 @@ struct ggml_tensor * ggml_norm_inplace(
 static struct ggml_tensor * ggml_rms_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
+        float eps,
         bool inplace) {
     bool is_node = false;
 
@@ -5790,7 +5791,7 @@ static struct ggml_tensor * ggml_rms_norm_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    // TODO: maybe store epsilon here?
+    ggml_set_op_params(result, &eps, sizeof(eps));
 
     result->op = GGML_OP_RMS_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5801,14 +5802,16 @@ static struct ggml_tensor * ggml_rms_norm_impl(
 
 struct ggml_tensor * ggml_rms_norm(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-    return ggml_rms_norm_impl(ctx, a, false);
+        struct ggml_tensor * a,
+        float eps) {
+    return ggml_rms_norm_impl(ctx, a, eps, false);
 }
 
 struct ggml_tensor * ggml_rms_norm_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor * a) {
-    return ggml_rms_norm_impl(ctx, a, true);
+        struct ggml_tensor * a,
+        float eps) {
+    return ggml_rms_norm_impl(ctx, a, eps, true);
 }
 
 struct ggml_tensor * ggml_rms_norm_back(
@@ -10131,7 +10134,8 @@ static void ggml_compute_forward_rms_norm_f32(
 
     GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const float eps = 1e-6f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
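This is the core plumbing of the change: the epsilon is serialized into the node's op_params byte area when the graph node is created (ggml_set_op_params(result, &eps, sizeof(eps))) and the CPU, CUDA, and Metal compute paths read it back with memcpy. A minimal standalone sketch of that store/read round trip, with a hypothetical byte buffer standing in for ggml_tensor::op_params:

    // illustrative sketch only; op_params in ggml is a small per-tensor byte area
    #include <assert.h>
    #include <stdint.h>
    #include <string.h>

    #define MAX_OP_PARAMS 32

    int main(void) {
        uint8_t op_params[MAX_OP_PARAMS] = {0};   // stands in for ggml_tensor::op_params

        // producer side (ggml_rms_norm_impl): store eps as raw bytes
        const float eps_in = 1e-6f;
        memcpy(op_params, &eps_in, sizeof(eps_in));

        // consumer side (CPU/CUDA/Metal compute paths): read it back at compute time
        float eps_out;
        memcpy(&eps_out, op_params, sizeof(float));

        assert(eps_out == eps_in);                // exact round trip, no conversion involved
        return 0;
    }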
ggml.h (7 changed lines)
@@ -866,14 +866,17 @@ extern "C" {
 
     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor * a);
+            struct ggml_tensor * a,
+            float eps);
 
     GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor * a);
+            struct ggml_tensor * a,
+            float eps);
 
     // a - x
     // b - dy
+    // TODO: update with configurable eps
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
             struct ggml_context * ctx,
             struct ggml_tensor * a,
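With the updated header, callers supply the epsilon explicitly at the call site. A hedged usage sketch (not from the commit; the buffer size and tensor shape are made up for illustration) of building an RMS-norm node with the new signature:

    #include <string.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // activations [n_embd, n_tokens]; sizes are illustrative
        struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 512, 8);

        // epsilon is now an explicit argument instead of a constant baked into the op
        struct ggml_tensor * y = ggml_rms_norm(ctx, x, 1e-6f);

        // the value travels with the node; backends recover it from op_params
        float eps;
        memcpy(&eps, y->op_params, sizeof(float));

        ggml_free(ctx);
        return 0;
    }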
llama.cpp (10 changed lines)
@@ -1396,11 +1396,15 @@ static bool llama_eval_internal(
     const int64_t n_vocab = hparams.n_vocab;
     const int64_t n_embd_gqa = hparams.n_embd_gqa();
 
+
     LLAMA_ASSERT(n_embd_head == hparams.n_rot);
 
     const float freq_base = hparams.rope_freq_base;
     const float freq_scale = hparams.rope_freq_scale;
 
+    // TODO: read from hparams
+    const float rms_norm_eps = 1e-6f;
+
     const int n_gpu_layers = model.n_gpu_layers;
 
     auto & mem_per_token = lctx.mem_per_token;
@@ -1479,7 +1483,7 @@
 
         // norm
         {
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             offload_func(cur);
             ggml_set_name(cur, "rms_norm_0");
 
@@ -1627,7 +1631,7 @@
     {
         // norm
         {
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
             offload_func(cur);
             ggml_set_name(cur, "rms_norm_1");
 
@@ -1680,7 +1684,7 @@
 
     // norm
     {
-        cur = ggml_rms_norm(ctx0, inpL);
+        cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         offload_func_nr(cur);
         ggml_set_name(cur, "rms_norm_2");
 
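The llama.cpp side still hard-codes rms_norm_eps = 1e-6f next to a "// TODO: read from hparams" comment. Purely as an illustration of where that TODO points (the struct and function names below are hypothetical, not part of this commit), the epsilon could live in the model hyperparameters and fall back to the current default:

    // hypothetical sketch of the "read from hparams" TODO; not code from this commit
    struct example_hparams {
        int   n_embd;
        float rms_norm_eps;       // 0.0f when the model file does not provide a value
    };

    static float pick_rms_norm_eps(const struct example_hparams * hp) {
        return hp->rms_norm_eps > 0.0f ? hp->rms_norm_eps : 1e-6f;  // current default
    }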