ggml : pass eps to ggml_norm

Author: Georgi Gerganov
Date:   2023-08-23 10:40:58 +03:00
parent d561b7f724
commit e3c52bd990
4 changed files with 49 additions and 40 deletions

ggml-metal.m

@@ -938,7 +938,8 @@ void ggml_metal_graph_compute(
                 } break;
             case GGML_OP_NORM:
                 {
-                    const float eps = 1e-5f;
+                    float eps;
+                    memcpy(&eps, dst->op_params, sizeof(float));

                     const int nth = 256;

ggml.c

@@ -5789,6 +5789,7 @@ struct ggml_tensor * ggml_silu_back(
 static struct ggml_tensor * ggml_norm_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        float eps,
         bool inplace) {
     bool is_node = false;

@@ -5799,7 +5800,7 @@ static struct ggml_tensor * ggml_norm_impl(
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

-    // TODO: maybe store epsilon here?
+    ggml_set_op_params(result, &eps, sizeof(eps));

     result->op   = GGML_OP_NORM;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
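
Note on the mechanism: ggml_norm_impl now stores the epsilon in the tensor's per-op parameter buffer (op_params) via the ggml_set_op_params helper, and each backend copies it back out when executing the node. A minimal standalone sketch of that round trip; only the op_params field name comes from ggml, the rest is illustrative:

    #include <stdint.h>
    #include <string.h>

    /* stand-in for the op_params scratch area of a ggml tensor */
    static int32_t op_params[16];

    /* graph-build time: store the raw float bits; memcpy is used because
       op_params is an int32_t array, so a direct float assignment would
       rely on type punning */
    static void set_eps(float eps) {
        memcpy(op_params, &eps, sizeof(eps));
    }

    /* compute time (the CPU and Metal paths above): read the same bits back */
    static float get_eps(void) {
        float eps;
        memcpy(&eps, op_params, sizeof(eps));
        return eps;
    }
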
@@ -5810,14 +5811,16 @@ static struct ggml_tensor * ggml_norm_impl(
 struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_norm_impl(ctx, a, false);
+        struct ggml_tensor  * a,
+        float eps) {
+    return ggml_norm_impl(ctx, a, eps, false);
 }

 struct ggml_tensor * ggml_norm_inplace(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
-    return ggml_norm_impl(ctx, a, true);
+        struct ggml_tensor  * a,
+        float eps) {
+    return ggml_norm_impl(ctx, a, eps, true);
 }

 // ggml_rms_norm

@@ -10619,7 +10622,8 @@ static void ggml_compute_forward_norm_f32(
     GGML_TENSOR_UNARY_OP_LOCALS;

-    const float eps = 1e-5f; // TODO: make this a parameter
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));

     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
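
For reference, GGML_OP_NORM normalizes each row to zero mean and unit variance, and eps only guards the division when the variance is near zero. A plain scalar sketch of the per-row computation (a reference implementation, not the actual strided ggml loop):

    #include <math.h>
    #include <stddef.h>

    /* layer norm without scale/shift for one row of n floats */
    static void norm_row(const float * x, float * y, size_t n, float eps) {
        double mean = 0.0;
        for (size_t i = 0; i < n; i++) {
            mean += x[i];
        }
        mean /= (double) n;

        double var = 0.0;
        for (size_t i = 0; i < n; i++) {
            const double d = x[i] - mean;
            var += d*d;
        }
        var /= (double) n;

        const float scale = 1.0f/sqrtf((float) var + eps);
        for (size_t i = 0; i < n; i++) {
            y[i] = (float) (x[i] - mean)*scale;
        }
    }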

ggml.h

@@ -909,14 +909,15 @@ extern "C" {
             struct ggml_tensor  * b);

     // normalize along rows
-    // TODO: eps is hardcoded to 1e-5 for now
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float eps);

     GGML_API struct ggml_tensor * ggml_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float eps);

     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
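
For callers, the visible change is the extra eps argument on ggml_norm and ggml_norm_inplace. A minimal usage sketch of the new signature (tensor shape and eps value chosen arbitrarily; graph execution omitted, and the context is leaked for brevity):

    #include "ggml.h"

    static struct ggml_tensor * build_norm_example(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 8);

        // before this commit: ggml_norm(ctx, x);  (eps was fixed at 1e-5f inside ggml)
        // after: the caller supplies the model's own epsilon
        struct ggml_tensor * y = ggml_norm(ctx, x, 1e-5f);

        return y;
    }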

llama.cpp

@@ -830,6 +830,7 @@ struct llama_hparams {
     uint32_t n_rot = 64;
     uint32_t n_ff  = 11008;

+    float f_norm_eps     = 1e-5;
     float f_norm_rms_eps = 1e-5;

     float rope_freq_base = 10000.0f;

@@ -1557,6 +1558,7 @@ static void llm_load_hparams(
                 } break;
             case LLM_ARCH_FALCON:
                 {
+                    GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
                 } break;
             default: (void)0;
     };
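
GGUF_GET_KEY is a llama.cpp macro that looks the key up in the model's GGUF metadata and, with the `true` argument, treats a missing key as an error. A rough equivalent using the public gguf API; the literal key string is an assumption here (the real code builds it with kv(LLM_KV_ATTENTION_LAYERNORM_EPS)):

    #include "ggml.h"   /* the gguf_* functions are declared here in this tree */

    /* sketch: read Falcon's layer-norm epsilon from an already-loaded GGUF context */
    static float load_falcon_norm_eps(const struct gguf_context * gctx) {
        const int kid = gguf_find_key(gctx, "falcon.attention.layer_norm_epsilon");
        if (kid < 0) {
            /* the diff marks the key as required, so llama.cpp would error out
               here rather than silently keep the 1e-5 default */
            return 1e-5f;
        }
        return gguf_get_val_f32(gctx, kid);
    }
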
@@ -1672,28 +1674,29 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     const auto & vocab   = model.vocab;

     // hparams
     LLAMA_LOG_INFO("%s: format         = %s\n",     __func__, llama_file_version_name(ml.fver));
     LLAMA_LOG_INFO("%s: arch           = %s\n",     __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
     LLAMA_LOG_INFO("%s: vocab type     = %s\n",     __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
     LLAMA_LOG_INFO("%s: n_vocab        = %u\n",     __func__, hparams.n_vocab);
     LLAMA_LOG_INFO("%s: n_ctx_train    = %u\n",     __func__, hparams.n_ctx_train);
     LLAMA_LOG_INFO("%s: n_ctx          = %u\n",     __func__, hparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_embd         = %u\n",     __func__, hparams.n_embd);
     LLAMA_LOG_INFO("%s: n_head         = %u\n",     __func__, hparams.n_head);
     LLAMA_LOG_INFO("%s: n_head_kv      = %u\n",     __func__, hparams.n_head_kv);
     LLAMA_LOG_INFO("%s: n_layer        = %u\n",     __func__, hparams.n_layer);
     LLAMA_LOG_INFO("%s: n_rot          = %u\n",     __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
     LLAMA_LOG_INFO("%s: n_gqa          = %u\n",     __func__, hparams.n_gqa());
-    LLAMA_LOG_INFO("%s: f_norm_eps     = %.1e\n",   __func__, hparams.f_norm_rms_eps);
+    LLAMA_LOG_INFO("%s: f_norm_eps     = %.1e\n",   __func__, hparams.f_norm_eps);
+    LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n",   __func__, hparams.f_norm_rms_eps);
     LLAMA_LOG_INFO("%s: n_ff           = %u\n",     __func__, hparams.n_ff);
     LLAMA_LOG_INFO("%s: freq_base      = %.1f\n",   __func__, hparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale     = %g\n",     __func__, hparams.rope_freq_scale);
     LLAMA_LOG_INFO("%s: model type     = %s\n",     __func__, llama_model_type_name(model.type));
     LLAMA_LOG_INFO("%s: model ftype    = %s\n",     __func__, llama_model_ftype_name(model.ftype).c_str());
     LLAMA_LOG_INFO("%s: model size     = %.2f B\n", __func__, ml.n_elements*1e-9);

     // general kv
     LLAMA_LOG_INFO("%s: general.name   = %s\n",     __func__, model.name.c_str());

     // special tokens
     if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }

@@ -1899,8 +1902,7 @@ static void llm_load_tensors(
             mmapped_size - vram_weights; // weights in VRAM not in memory

         // this is the memory required by one llama_state
-        const size_t mem_required_state =
-            scale*hparams.kv_size();
+        const size_t mem_required_state = scale*hparams.kv_size();

         LLAMA_LOG_INFO("%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
                 mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
@@ -2383,6 +2385,10 @@ static struct ggml_cgraph * llm_build_falcon(
     GGML_ASSERT(n_embd_head == hparams.n_rot);

+    const float freq_base  = hparams.rope_freq_base;
+    const float freq_scale = hparams.rope_freq_scale;
+    const float norm_eps   = hparams.f_norm_eps;
+
     auto & buf_compute = lctx.buf_compute;

     struct ggml_init_params params = {

@@ -2436,7 +2442,7 @@ static struct ggml_cgraph * llm_build_falcon(
         // self-attention
         {
-            attn_norm = ggml_norm(ctx0, inpL);
+            attn_norm = ggml_norm(ctx0, inpL, norm_eps);

             attn_norm = ggml_add(ctx0,
                     ggml_mul(ctx0,

@@ -2445,7 +2451,7 @@ static struct ggml_cgraph * llm_build_falcon(
                     ggml_repeat(ctx0, model.layers[il].attn_norm_b, attn_norm));

             if (model.layers[il].attn_norm_2) { // Falcon-40B
-                cur = ggml_norm(ctx0, inpL);
+                cur = ggml_norm(ctx0, inpL, norm_eps);

                 cur = ggml_add(ctx0,
                         ggml_mul(ctx0,

@@ -2490,8 +2496,8 @@ static struct ggml_cgraph * llm_build_falcon(
                 wsize * n_embd_head * (n_head + n_head_kv));

             // using mode = 2 for neox mode
-            Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0);
-            Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0);
+            Qcur = ggml_rope_custom_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
+            Kcur = ggml_rope_custom_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0, freq_base, freq_scale);

             // store key and value to memory
             {
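
The switch from ggml_rope_inplace to ggml_rope_custom_inplace only threads freq_base and freq_scale through; with the old defaults (10000.0f and 1.0f) the rotation is unchanged. A reference sketch of how the two parameters enter the per-dimension rotation angle (formula only, not the actual ggml kernel):

    #include <math.h>

    /* angle applied to dimension pair i (0 <= i < n_dims/2) at position pos */
    static float rope_theta(int pos, int i, int n_dims, float freq_base, float freq_scale) {
        return freq_scale * (float) pos * powf(freq_base, -2.0f*(float) i/(float) n_dims);
    }
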
@@ -2522,8 +2528,6 @@ static struct ggml_cgraph * llm_build_falcon(
             // K * Q

-            // K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy));
-
             struct ggml_tensor * Q  = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

@@ -2549,7 +2553,6 @@ static struct ggml_cgraph * llm_build_falcon(
                         n_embd_head, n_head_kv, n_past + N),
                     0, 2, 1, 3);

-            // V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy)));
             V = ggml_cont(ctx0, ggml_transpose(ctx0, V));

             // KQV = transpose(V) * KQ_soft_max

@@ -2589,7 +2592,7 @@ static struct ggml_cgraph * llm_build_falcon(
     // norm
     {
-        cur = ggml_norm(ctx0, inpL);
+        cur = ggml_norm(ctx0, inpL, norm_eps);

         cur = ggml_add(ctx0,
                 ggml_mul(ctx0,
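
Taken together, every ggml_norm call site in llm_build_falcon follows the same pattern: normalize with the model's epsilon, then apply the learned scale and shift. A condensed sketch of that recurring block, with the context, input, and weight tensors passed in as parameters rather than taken from the surrounding graph builder:

    #include "ggml.h"

    /* norm block as used in the Falcon graph: y = norm(x, eps)*w + b */
    static struct ggml_tensor * norm_block(
            struct ggml_context * ctx0,
            struct ggml_tensor  * x,
            struct ggml_tensor  * w,       /* e.g. model.layers[il].attn_norm   */
            struct ggml_tensor  * b,       /* e.g. model.layers[il].attn_norm_b */
            float                 norm_eps) {
        struct ggml_tensor * t = ggml_norm(ctx0, x, norm_eps);   // (x - mean)/sqrt(var + eps)
        t = ggml_mul(ctx0, ggml_repeat(ctx0, w, t), t);          // scale by the norm weight
        t = ggml_add(ctx0, t, ggml_repeat(ctx0, b, t));          // shift by the norm bias
        return t;
    }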