sampling : refactor init to use llama_sampling_params

Georgi Gerganov 2023-10-20 14:58:20 +03:00
parent 8cf19d60dc
commit cd1e937821
12 changed files with 110 additions and 142 deletions
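
In short, llama_sampling_init now takes only the sampling-specific llama_sampling_params instead of the whole gpt_params, and the gpt_params member is renamed from sampling_params to sparams. A minimal sketch of a call site before and after (the surrounding function is hypothetical, shown only to illustrate the signature change):

#include "common.h"
#include "sampling.h"

void example(const gpt_params & params) {
    // before this commit, the whole gpt_params struct was passed:
    //   llama_sampling_context * ctx_sampling = llama_sampling_init(params);
    // after it, only the nested sampling parameters are passed:
    llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
    llama_sampling_free(ctx_sampling);
}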

common/common.cpp

@@ -107,7 +107,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::string arg;
gpt_params default_params;
const std::string arg_prefix = "--";
- llama_sampling_params & sparams = params.sampling_params;
+ llama_sampling_params & sparams = params.sparams;
for (int i = 1; i < argc; i++) {
arg = argv[i];
@@ -572,7 +572,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
- params.grammar = argv[i];
+ sparams.grammar = argv[i];
} else if (arg == "--grammar-file") {
if (++i >= argc) {
invalid_param = true;
@@ -587,7 +587,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
- std::back_inserter(params.grammar)
+ std::back_inserter(sparams.grammar)
);
#ifndef LOG_DISABLE_LOGS
// Parse args for logging parameters
@@ -640,7 +640,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
}
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
- const llama_sampling_params & sparams = params.sampling_params;
+ const llama_sampling_params & sparams = params.sparams;
printf("usage: %s [options]\n", argv[0]);
printf("\n");
@@ -878,7 +878,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}
if (params.ignore_eos) {
- params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
+ params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY;
}
{
@@ -1123,28 +1123,28 @@ std::string get_sortable_timestamp() {
void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
- const llama_sampling_params & sparams = params.sampling_params;
+ const llama_sampling_params & sparams = params.sparams;
fprintf(stream, "build_commit: %s\n", BUILD_COMMIT);
fprintf(stream, "build_number: %d\n", BUILD_NUMBER);
fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false");
fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false");
fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false");
fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false");
fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false");
fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false");
fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false");
fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false");
fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false");
fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false");
fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false");
fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false");
fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false");
fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false");
fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false");
fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false");
fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false");
fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false");
fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false");
fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false");
fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false");
#ifdef NDEBUG
fprintf(stream, "debug: false\n");
@@ -1179,7 +1179,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
- dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
+ dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);

common/common.h

@@ -56,7 +56,7 @@ struct gpt_params {
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
// sampling parameters
- struct llama_sampling_params sampling_params;
+ struct llama_sampling_params sparams;
std::string model = "models/7B/ggml-model-f16.gguf"; // model path
std::string model_draft = ""; // draft model for speculative decoding
@@ -66,7 +66,6 @@ struct gpt_params {
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
- std::string grammar = ""; // optional BNF-like grammar to constrain sampling
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files
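
With grammar removed from gpt_params and moved into llama_sampling_params (see the sampling.h hunk below), user code reaches it through the nested sparams member. A hedged sketch, with an illustrative grammar string:

#include "common.h"

void configure(gpt_params & params) {
    // previously: params.grammar = "root ::= [a-z]+";
    params.sparams.grammar = "root ::= [a-z]+"; // the BNF-like grammar now lives in the sampling params
}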

common/sampling.cpp

@@ -1,9 +1,9 @@
#include "sampling.h"
- struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
+ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();
- result->params = params.sampling_params;
+ result->params = params;
result->grammar = nullptr;
// if there is a grammar, parse it
@@ -23,7 +23,7 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_params & pa
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
}
- result->prev.resize(params.n_ctx);
+ result->prev.resize(params.n_prev);
return result;
}
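
Note the resize now uses the new n_prev field (default 256, added in sampling.h below) rather than n_ctx, so the token history kept for sampling is bounded independently of the context size. A rough sketch of the effect, with an illustrative value:

#include "sampling.h"

void example() {
    llama_sampling_params sparams;
    sparams.n_prev = 64; // keep only the last 64 tokens of history (default is 256)
    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    // ctx_sampling->prev now has 64 slots instead of one per context token
    llama_sampling_free(ctx_sampling);
}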

common/sampling.h

@@ -22,6 +22,7 @@ typedef struct llama_sampling_params {
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
+ int32_t n_prev = 256; // number of previous tokens to remember
bool penalize_nl = true; // consider newlines as a repeatable token
@@ -34,6 +35,8 @@ typedef struct llama_sampling_params {
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+ std::string grammar = ""; // optional BNF-like grammar to constrain sampling
} llama_sampling_params;
// general sampler context
@@ -58,7 +61,7 @@ struct llama_sampling_context {
#include "common.h"
// Create a new sampling context instance.
- struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
+ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
void llama_sampling_free(struct llama_sampling_context * ctx);
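
Taken together, the header now exposes a self-contained sampling configuration. A minimal usage sketch of the declarations above (field values are illustrative; error handling omitted):

#include "sampling.h"

int main() {
    llama_sampling_params sparams;
    sparams.mirostat     = 2;    // 2 = mirostat 2.0, per the struct defaults above
    sparams.mirostat_tau = 5.0f; // target entropy
    sparams.grammar      = "";   // no grammar constraint
    llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    // ... use ctx_sampling while generating tokens ...
    llama_sampling_free(ctx_sampling);
    return 0;
}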