llama : fix Windows build + fix norm_rms_eps key

This commit is contained in:
Georgi Gerganov 2023-08-16 13:09:43 +03:00
parent 31fb56e1d3
commit c1fe0aba72
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -63,24 +63,25 @@
#include <stdio.h> // for _fseeki64 #include <stdio.h> // for _fseeki64
#endif #endif
#include <array>
#include <ctime>
#include <cinttypes>
#include <fstream>
#include <random>
#include <map>
#include <unordered_map>
#include <queue>
#include <cassert>
#include <cstring>
#include <climits>
#include <memory>
#include <algorithm> #include <algorithm>
#include <array>
#include <cassert>
#include <cinttypes>
#include <climits>
#include <cstdarg>
#include <cstring>
#include <ctime>
#include <fstream>
#include <initializer_list> #include <initializer_list>
#include <thread> #include <map>
#include <memory>
#include <mutex> #include <mutex>
#include <sstream>
#include <numeric> #include <numeric>
#include <queue>
#include <random>
#include <sstream>
#include <thread>
#include <unordered_map>
#if defined(_MSC_VER) #if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
@ -136,11 +137,12 @@ __attribute__((format(printf, 1, 2)))
#endif #endif
#endif #endif
static std::string format(const char * fmt, ...) { static std::string format(const char * fmt, ...) {
va_list ap, ap2; va_list ap;
va_list ap2;
va_start(ap, fmt); va_start(ap, fmt);
va_copy(ap2, ap); va_copy(ap2, ap);
int size = vsnprintf(NULL, 0, fmt, ap); int size = vsnprintf(NULL, 0, fmt, ap);
GGML_ASSERT(size >= 0 && size < INT_MAX); GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
std::vector<char> buf(size + 1); std::vector<char> buf(size + 1);
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
GGML_ASSERT(size2 == size); GGML_ASSERT(size2 == size);
@ -668,7 +670,7 @@ struct llama_hparams {
uint32_t n_rot = 64; uint32_t n_rot = 64;
uint32_t n_ff = 11008; uint32_t n_ff = 11008;
float f_rms_norm_eps = 1e-5; float f_norm_rms_eps = 1e-5;
float rope_freq_base = 10000.0f; float rope_freq_base = 10000.0f;
float rope_freq_scale = 1.0f; float rope_freq_scale = 1.0f;
@ -1279,7 +1281,7 @@ static void llama_model_load_internal(
hparams.n_head = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.attention.head_count")); hparams.n_head = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.attention.head_count"));
hparams.n_layer = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.block_count")); hparams.n_layer = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.block_count"));
hparams.n_rot = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.rope.dimension_count")); hparams.n_rot = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.rope.dimension_count"));
hparams.f_rms_norm_eps = gguf_get_val_f32(ctx, gguf_find_key(ctx, "llama.rms_norm_epsilon")); hparams.f_norm_rms_eps = gguf_get_val_f32(ctx, gguf_find_key(ctx, "llama.attention.layer_norm_rms_epsilon"));
// n_head_kv default to n_head // n_head_kv default to n_head
hparams.n_head_kv = hparams.n_head; hparams.n_head_kv = hparams.n_head;
@ -1360,7 +1362,7 @@ static void llama_model_load_internal(
LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer);
LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa());
LLAMA_LOG_INFO("%s: rnorm_eps = %.1e\n", __func__, hparams.f_rms_norm_eps); LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
@ -1658,9 +1660,9 @@ static struct ggml_cgraph * llama_build_graph(
GGML_ASSERT(n_embd_head == hparams.n_rot); GGML_ASSERT(n_embd_head == hparams.n_rot);
const float freq_base = hparams.rope_freq_base; const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale; const float freq_scale = hparams.rope_freq_scale;
const float rms_norm_eps = hparams.f_rms_norm_eps; const float norm_rms_eps = hparams.f_norm_rms_eps;
const int n_gpu_layers = model.n_gpu_layers; const int n_gpu_layers = model.n_gpu_layers;
@ -1767,7 +1769,7 @@ static struct ggml_cgraph * llama_build_graph(
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
offload_func(cur); offload_func(cur);
ggml_set_name(cur, "rms_norm_0"); ggml_set_name(cur, "rms_norm_0");
@ -1912,7 +1914,7 @@ static struct ggml_cgraph * llama_build_graph(
{ {
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
offload_func(cur); offload_func(cur);
ggml_set_name(cur, "rms_norm_1"); ggml_set_name(cur, "rms_norm_1");
@ -1962,7 +1964,7 @@ static struct ggml_cgraph * llama_build_graph(
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
offload_func_nr(cur); offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_2"); ggml_set_name(cur, "rms_norm_2");