add rms_norm_eps to command line

parent: 9fe47c747f
commit: 24e53a1466

4 changed files with 32 additions and 16 deletions
@@ -177,6 +177,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_gqa = std::stoi(argv[i]);
+        } else if (arg == "-eps" || arg == "--rms-norm-eps") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rms_norm_eps = std::stof(argv[i]);
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -519,6 +525,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
+    fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %f)\n", params.rms_norm_eps);
     fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
     fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
     fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -615,6 +622,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     lparams.n_ctx = params.n_ctx;
     lparams.n_batch = params.n_batch;
     lparams.n_gqa = params.n_gqa;
+    lparams.rms_norm_eps = params.rms_norm_eps;
     lparams.n_gpu_layers = params.n_gpu_layers;
     lparams.main_gpu = params.main_gpu;
     lparams.tensor_split = params.tensor_split;
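With the three hunks above in place, a LLaMA-v2 run can pass the epsilon on the command line alongside the existing GQA override, e.g. `-gqa 8 -eps 1e-5` (the values suggested by the usage strings); if the flag is omitted, the value falls back to the 1e-6 default declared in the gpt_params hunk that follows.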
@@ -22,18 +22,19 @@
 int32_t get_num_physical_cores();
 
 struct gpt_params {
     uint32_t seed = -1; // RNG seed
     int32_t n_threads = get_num_physical_cores();
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_ctx = 512; // context size
     int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_gpu_layers = 0; // number of layers to store in VRAM
     int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
     float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
     int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+    float rms_norm_eps = 1e-6; // rms norm epsilon
     float rope_freq_base = 10000.0f; // RoPE base frequency
     float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
llama.cpp (16 changed lines)
@@ -186,6 +186,7 @@ struct llama_hparams {
     // LLaMAv2
     // TODO: load from model data hparams
     float f_ffn_mult = 1.0f;
+    float f_rms_norm_eps = 1e-6f;
 
     float rope_freq_base = 10000.0f;
     float rope_freq_scale = 1.0f;
@@ -869,6 +870,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_ctx =*/ 512,
         /*.n_batch =*/ 512,
         /*.n_gqa =*/ 1,
+        /*.rms_norm_eps =*/ 1e-6f,
         /*.gpu_layers =*/ 0,
         /*.main_gpu =*/ 0,
         /*.tensor_split =*/ nullptr,
@@ -1000,6 +1002,7 @@ static void llama_model_load_internal(
         int n_ctx,
         int n_batch,
         int n_gqa,
+        float rms_norm_eps,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
@ -1024,6 +1027,9 @@ static void llama_model_load_internal(
|
||||||
|
|
||||||
auto & hparams = model.hparams;
|
auto & hparams = model.hparams;
|
||||||
|
|
||||||
|
// TODO: read from file
|
||||||
|
hparams.f_rms_norm_eps = rms_norm_eps;
|
||||||
|
|
||||||
{
|
{
|
||||||
switch (hparams.n_layer) {
|
switch (hparams.n_layer) {
|
||||||
case 26: model.type = e_model::MODEL_3B; break;
|
case 26: model.type = e_model::MODEL_3B; break;
|
||||||
|
@@ -1072,6 +1078,7 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
         fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
         fprintf(stderr, "%s: n_gqa = %u\n", __func__, hparams.n_gqa());
+        fprintf(stderr, "%s: rnorm_eps = %.1e\n", __func__, hparams.f_rms_norm_eps);
         fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
         fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
         fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
@@ -1330,6 +1337,7 @@ static bool llama_model_load(
         int n_ctx,
         int n_batch,
         int n_gqa,
+        float rms_norm_eps,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
@@ -1343,7 +1351,7 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
+        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
                 use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::exception & err) {
@@ -1401,9 +1409,7 @@ static bool llama_eval_internal(
 
     const float freq_base = hparams.rope_freq_base;
     const float freq_scale = hparams.rope_freq_scale;
-
-    // TODO: read from hparams
-    const float rms_norm_eps = 1e-6f;
+    const float rms_norm_eps = hparams.f_rms_norm_eps;
 
     const int n_gpu_layers = model.n_gpu_layers;
 
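For reference, the epsilon being threaded through the evaluation code is the usual stabilizer inside RMS normalization. A minimal, self-contained sketch of the operation (illustrative only, not ggml's kernel; the learned scale weights are omitted):

#include <cmath>
#include <vector>

// Normalize x in place by its root-mean-square; eps keeps the denominator
// away from zero when the activations are very small.
static void rms_norm(std::vector<float> & x, float eps) {
    double sum_sq = 0.0;
    for (float v : x) {
        sum_sq += (double) v * v;
    }
    const float scale = 1.0f / std::sqrt((float)(sum_sq / x.size()) + eps);
    for (float & v : x) {
        v *= scale;
    }
}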
@@ -3088,7 +3094,7 @@ struct llama_model * llama_load_model_from_file(
 
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.n_gpu_layers,
+    if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers,
                 params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
                 memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
                 params.progress_callback_user_data)) {
llama.h (1 changed line)
@@ -87,6 +87,7 @@ extern "C" {
         int32_t n_ctx; // text context
         int32_t n_batch; // prompt processing batch size
         int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams)
+        float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams)
         int32_t n_gpu_layers; // number of layers to store in VRAM
         int32_t main_gpu; // the GPU that is used for scratch and small tensors
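Taken together, the new value travels from the command line into gpt_params, on into llama_context_params, and finally into the model hparams at load time. A minimal caller-side sketch of the same path through the C API (assuming the llama_load_model_from_file / llama_new_context_with_model entry points of this version; the model path is a placeholder):

#include "llama.h"

int main() {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_gqa        = 8;     // TEMP override for LLaMAv2 70B, per the usage text
    cparams.rms_norm_eps = 1e-5f; // value suggested for LLaMAv2

    // placeholder path; error handling kept minimal
    llama_model * model = llama_load_model_from_file("models/llama-2-70b.bin", cparams);
    if (model == nullptr) {
        return 1;
    }
    llama_context * ctx = llama_new_context_with_model(model, cparams);
    // ... evaluate tokens with ctx ...
    llama_free(ctx);
    llama_free_model(model);
    return 0;
}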