sampling : refactor init to use llama_sampling_params (#3696)
* sampling : refactor init to use llama_sampling_params * llama : combine repetition, frequency and presence penalties in 1 call * examples : remove embd-input and gptneox-wip * sampling : rename penalty params + reduce size of "prev" vector * sampling : add llama_sampling_print helper * sampling : hide prev behind API and apply #3661 ggml-ci
This commit is contained in:
parent
8cf19d60dc
commit
d1031cf49c
30 changed files with 365 additions and 4502 deletions
|
@ -107,7 +107,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
std::string arg;
|
||||
gpt_params default_params;
|
||||
const std::string arg_prefix = "--";
|
||||
llama_sampling_params & sparams = params.sampling_params;
|
||||
llama_sampling_params & sparams = params.sparams;
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
arg = argv[i];
|
||||
|
@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.repeat_last_n = std::stoi(argv[i]);
|
||||
sparams.penalty_last_n = std::stoi(argv[i]);
|
||||
sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
|
||||
} else if (arg == "--repeat-penalty") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.repeat_penalty = std::stof(argv[i]);
|
||||
sparams.penalty_repeat = std::stof(argv[i]);
|
||||
} else if (arg == "--frequency-penalty") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.frequency_penalty = std::stof(argv[i]);
|
||||
sparams.penalty_freq = std::stof(argv[i]);
|
||||
} else if (arg == "--presence-penalty") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
sparams.presence_penalty = std::stof(argv[i]);
|
||||
sparams.penalty_present = std::stof(argv[i]);
|
||||
} else if (arg == "--mirostat") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -572,7 +573,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.grammar = argv[i];
|
||||
sparams.grammar = argv[i];
|
||||
} else if (arg == "--grammar-file") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -587,7 +588,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
std::copy(
|
||||
std::istreambuf_iterator<char>(file),
|
||||
std::istreambuf_iterator<char>(),
|
||||
std::back_inserter(params.grammar)
|
||||
std::back_inserter(sparams.grammar)
|
||||
);
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
// Parse args for logging parameters
|
||||
|
@ -640,7 +641,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
}
|
||||
|
||||
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||
const llama_sampling_params & sparams = params.sampling_params;
|
||||
const llama_sampling_params & sparams = params.sparams;
|
||||
|
||||
printf("usage: %s [options]\n", argv[0]);
|
||||
printf("\n");
|
||||
|
@ -678,10 +679,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
|
||||
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
|
||||
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
|
||||
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n);
|
||||
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty);
|
||||
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty);
|
||||
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty);
|
||||
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
|
||||
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
|
||||
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
|
||||
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
|
||||
printf(" --mirostat N use Mirostat sampling.\n");
|
||||
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
|
||||
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
|
||||
|
@ -878,7 +879,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
|||
}
|
||||
|
||||
if (params.ignore_eos) {
|
||||
params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
|
||||
params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY;
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -1123,28 +1124,28 @@ std::string get_sortable_timestamp() {
|
|||
|
||||
void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx,
|
||||
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
|
||||
const llama_sampling_params & sparams = params.sampling_params;
|
||||
const llama_sampling_params & sparams = params.sparams;
|
||||
|
||||
fprintf(stream, "build_commit: %s\n", BUILD_COMMIT);
|
||||
fprintf(stream, "build_number: %d\n", BUILD_NUMBER);
|
||||
fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false");
|
||||
fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false");
|
||||
|
||||
#ifdef NDEBUG
|
||||
fprintf(stream, "debug: false\n");
|
||||
|
@ -1178,8 +1179,8 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
|||
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
|
||||
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
|
||||
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
|
||||
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
|
||||
dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
|
||||
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
|
||||
dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
|
||||
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
|
||||
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
|
||||
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
|
||||
|
@ -1238,14 +1239,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
|||
fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
|
||||
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
|
||||
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
|
||||
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty);
|
||||
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
|
||||
dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
|
||||
fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
|
||||
fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
|
||||
fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
|
||||
dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
|
||||
fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
|
||||
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty);
|
||||
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);
|
||||
|
||||
fprintf(stream, "reverse_prompt:\n");
|
||||
for (std::string ap : params.antiprompt) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue