From 6b7103cccd913934b80e2a588735e38dad909382 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 10 Aug 2024 14:17:44 +0300
Subject: [PATCH] llama : introduce llama_sampling_params (wip)

ggml-ci
---
 common/common.cpp          |  8 ++++----
 common/common.h            |  3 +--
 common/sampling.cpp        | 14 +++++++-------
 common/sampling.h          | 15 ++++++---------
 examples/infill/infill.cpp |  3 ++-
 examples/main/main.cpp     |  2 +-
 examples/server/server.cpp | 15 +++++++++------
 include/llama.h            | 31 +++++++++++++++++++++++++++----
 src/llama.cpp              | 29 ++++++++++++++++++++++++++++-
 9 files changed, 85 insertions(+), 35 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index d115a958b..c56246b8f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -226,7 +226,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     bool invalid_param = false;
     std::string arg;
     const std::string arg_prefix = "--";
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
     for (int i = 1; i < argc; i++) {
         arg = argv[i];
@@ -294,7 +294,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 
 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     const char split_delim = ',';
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
     if (arg == "-s" || arg == "--seed") {
         CHECK_ARG
@@ -1375,7 +1375,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
 #endif
 
 void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;
 
     std::string sampler_type_chars;
     std::string sampler_type_names;
@@ -3116,7 +3116,7 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
 
 void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
                                const std::string & timestamp, const std::vector<llama_token> & prompt_tokens, const char * model_desc) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;
 
     fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
     fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER);
diff --git a/common/common.h b/common/common.h
index f9b08b18f..acb4c95ed 100644
--- a/common/common.h
+++ b/common/common.h
@@ -108,8 +108,7 @@ struct gpt_params {
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
-    // // sampling parameters
-    struct llama_sampling_params sparams;
+    struct gpt_sampling_params sparams;
 
     std::string model = ""; // model path
     std::string model_draft = ""; // draft model for speculative decoding
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 575baf747..4d57392c8 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,8 +1,8 @@
 #include "sampling.h"
 
-#include <random>
+#include "common.h"
 
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model) {
+struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_params & params, const struct llama_model * model) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
     result->params = params;
@@ -58,7 +58,7 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
     return result;
 }
 
-std::string llama_sampling_print(const llama_sampling_params & params) {
+std::string llama_sampling_print(const gpt_sampling_params & params) {
     char result[1024];
 
     snprintf(result, sizeof(result),
@@ -72,7 +72,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
     return std::string(result);
 }
 
-std::string llama_sampling_order_print(const llama_sampling_params & params) {
+std::string llama_sampling_order_print(const gpt_sampling_params & params) {
     std::string result = "CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
@@ -176,7 +176,7 @@ static void sampler_queue(
                    size_t min_keep) {
     llama_sampling * smpl = ctx_sampling->smpl;
 
-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;
 
     const float temp = params.temp;
     const float dynatemp_range = params.dynatemp_range;
@@ -217,7 +217,7 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {
     llama_sampling * smpl = ctx_sampling->smpl;
 
-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;
 
     const float temp = params.temp;
     const int mirostat = params.mirostat;
@@ -308,7 +308,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
                   std::vector<float> * original_logits) {
     llama_sampling * smpl = ctx_sampling->smpl;
 
-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
 
diff --git a/common/sampling.h b/common/sampling.h
index 244db47ba..59158ae85 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -2,7 +2,6 @@
 
 #include "llama.h"
 
-#include <random>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -18,7 +17,7 @@ enum class llama_sampler_type : char {
 };
 
 // sampling parameters
-typedef struct llama_sampling_params {
+typedef struct gpt_sampling_params {
     int32_t n_prev   = 64; // number of previous tokens to remember
     int32_t n_probs  = 0;  // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t min_keep = 0;  // 0 = disabled, otherwise samplers should return at least min_keep tokens
@@ -60,13 +59,13 @@ typedef struct llama_sampling_params {
 
     std::vector<llama_token> penalty_prompt_tokens;
     bool use_penalty_prompt_tokens = false;
-} llama_sampling_params;
+} gpt_sampling_params;
 
 // general sampler context
 // TODO: move to llama.h
 struct llama_sampling_context {
     // parameters that will be used for sampling
-    llama_sampling_params params;
+    gpt_sampling_params params;
 
     // mirostat sampler state
    float mirostat_mu;
@@ -80,10 +79,8 @@ struct llama_sampling_context {
     size_t n_valid; // Number of correct top tokens with correct probabilities.
 };
 
-#include "common.h"
-
 // Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model);
+struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_params & params, const struct llama_model * model);
 
 void llama_sampling_free(struct llama_sampling_context * ctx);
 
@@ -102,10 +99,10 @@ llama_token llama_sampling_last(llama_sampling_context * ctx);
 std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
 
 // Print sampling parameters into a string
-std::string llama_sampling_print(const llama_sampling_params & params);
+std::string llama_sampling_print(const gpt_sampling_params & params);
 
 // Print sampling order into a string
-std::string llama_sampling_order_print(const llama_sampling_params & params);
+std::string llama_sampling_order_print(const gpt_sampling_params & params);
 
 std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
 
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 555f59556..cf1f7661d 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -103,7 +103,6 @@ static void sigint_handler(int signo) {
 
 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sparams;
     g_params = &params;
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -111,6 +110,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    auto & sparams = params.sparams;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("infill", "log"));
     LOG_TEE("Log start\n");
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index d9037fb61..b17268903 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -137,7 +137,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("main", "log"));
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 5de1887c8..7899ccb75 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -170,11 +170,13 @@ struct server_slot {
     std::string stopping_word;
 
     // sampling
-    llama_token sampled;
-    struct llama_sampling_params sparams;
-    llama_sampling_context * ctx_sampling = nullptr;
     json json_schema;
 
+    struct gpt_sampling_params sparams;
+
+    llama_token sampled;
+    llama_sampling_context * ctx_sampling = nullptr;
+
     int32_t ga_i = 0;   // group-attention state
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
@@ -893,8 +895,8 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot_params default_params;
         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-        llama_sampling_params default_sparams = params.sparams;
-        auto & data = task.data;
+        auto default_sparams = params.sparams;
+        const auto & data = task.data;
 
         if (data.count("__oaicompat") != 0) {
             slot.oaicompat = true;
@@ -933,7 +935,8 @@ struct server_context {
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
             return false;
-        } else if (data.contains("json_schema") && !data.contains("grammar")) {
+        }
+        if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema = json_value(data, "json_schema", json::object());
                 slot.sparams.grammar = json_schema_to_grammar(schema);
diff --git a/include/llama.h b/include/llama.h
index 29cc7fddc..c80405f7f 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -56,6 +56,7 @@ extern "C" {
     // struct llama_vocab; // TODO: add in the future
     struct llama_model;
     struct llama_context;
+    struct llama_sampling;
 
     typedef int32_t llama_pos;
    typedef int32_t llama_token;
@@ -355,8 +356,29 @@ extern "C" {
         void * kv_overrides; // pointer to vector containing overrides
     } llama_model_quantize_params;
 
-    // sampling types
-    struct llama_sampling;
+    // parameters for sampling the logits
+    typedef struct llama_sampling_params {
+        uint32_t seed;              // the seed used to initialize llama_sampling_context
+        int32_t  n_prev;            // number of previous tokens to remember
+        int32_t  n_probs;           // if greater than 0, output the probabilities of top n_probs tokens.
+        int32_t  min_keep;          // 0 = disabled, otherwise samplers should return at least min_keep tokens
+        int32_t  top_k;             // <= 0 to use vocab size
+        float    top_p;             // 1.0 = disabled
+        float    min_p;             // 0.0 = disabled
+        float    tfs_z;             // 1.0 = disabled
+        float    typical_p;         // 1.0 = disabled
+        float    temp;              // <= 0.0 to sample greedily, 0.0 to not output probabilities
+        float    dynatemp_range;    // 0.0 = disabled
+        float    dynatemp_exponent; // controls how entropy maps to temperature in dynamic temperature sampler
+        int32_t  penalty_last_n;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+        float    penalty_repeat;    // 1.0 = disabled
+        float    penalty_freq;      // 0.0 = disabled
+        float    penalty_present;   // 0.0 = disabled
+        int32_t  mirostat;          // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+        float    mirostat_tau;      // target entropy
+        float    mirostat_eta;      // learning rate
+        bool     penalize_nl;       // consider newlines as a repeatable token
+    } llama_sampling_params;
 
     // performance timing information
     struct llama_timings {
@@ -385,9 +407,10 @@ extern "C" {
     struct llama_lora_adapter;
 
     // Helpers for getting default parameters
-    LLAMA_API struct llama_model_params llama_model_default_params(void);
-    LLAMA_API struct llama_context_params llama_context_default_params(void);
+    LLAMA_API struct llama_model_params          llama_model_default_params(void);
+    LLAMA_API struct llama_context_params        llama_context_default_params(void);
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
+    LLAMA_API struct llama_sampling_params       llama_sampling_default_params(void);
 
     // Initialize the llama + ggml backend
     // If numa is true, use NUMA optimizations
diff --git a/src/llama.cpp b/src/llama.cpp
index 1cdc032df..0eb3aea71 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -16490,6 +16490,33 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
     return result;
 }
 
+struct llama_sampling_params llama_sampling_default_params() {
+    struct llama_sampling_params result = {
+        /*.seed              =*/ LLAMA_DEFAULT_SEED,
+        /*.n_prev            =*/ 64,
+        /*.n_probs           =*/ 0,
+        /*.min_keep          =*/ 0,
+        /*.top_k             =*/ 40,
+        /*.top_p             =*/ 0.95f,
+        /*.min_p             =*/ 0.05f,
+        /*.tfs_z             =*/ 1.00f,
+        /*.typical_p         =*/ 1.00f,
+        /*.temp              =*/ 0.80f,
+        /*.dynatemp_range    =*/ 0.00f,
+        /*.dynatemp_exponent =*/ 1.00f,
+        /*.penalty_last_n    =*/ 64,
+        /*.penalty_repeat    =*/ 1.00f,
+        /*.penalty_freq      =*/ 0.00f,
+        /*.penalty_present   =*/ 0.00f,
+        /*.mirostat          =*/ 0,
+        /*.mirostat_tau      =*/ 5.00f,
+        /*.mirostat_eta      =*/ 0.10f,
+        /*.penalize_nl       =*/ false,
+    };
+
+    return result;
+}
+
 size_t llama_max_devices(void) {
 #if defined(GGML_USE_RPC)
     return GGML_RPC_MAX_SERVERS;
@@ -16731,7 +16758,7 @@ struct llama_context * llama_new_context_with_model(
 
     ctx->logits_all = params.logits_all;
 
     // build worst-case graph for encoder if a model contains encoder
-    ctx->is_encoding = llama_model_has_encoder(model);
+    ctx->is_encoding = llama_model_has_encoder(model);
 
     uint32_t kv_size = cparams.n_ctx;
     ggml_type type_k = params.type_k;
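
Note (not part of the patch): a minimal usage sketch of the public API introduced above. It assumes only what the diff adds to include/llama.h and src/llama.cpp, namely struct llama_sampling_params and llama_sampling_default_params(); the opaque struct llama_sampling is not used. The program below is hypothetical and simply takes the library defaults and overrides a few fields.

    #include "llama.h"

    #include <stdio.h>

    int main(void) {
        // start from the library-provided defaults added in src/llama.cpp
        struct llama_sampling_params sparams = llama_sampling_default_params();

        // override a few fields; everything else keeps its default value
        sparams.seed  = 1234;
        sparams.temp  = 0.7f;
        sparams.top_k = 50;

        printf("temp = %.2f, top_k = %d, top_p = %.2f, mirostat = %d\n",
               sparams.temp, (int) sparams.top_k, sparams.top_p, (int) sparams.mirostat);

        return 0;
    }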