llama : introduce llama_sampling_params (wip)

ggml-ci
Author: Georgi Gerganov
Date:   2024-08-10 14:17:44 +03:00
Commit: 6b7103cccd (parent: ae9d3f68e9)
9 changed files with 85 additions and 35 deletions


@@ -226,7 +226,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     bool invalid_param = false;
     std::string arg;
     const std::string arg_prefix = "--";
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;

     for (int i = 1; i < argc; i++) {
         arg = argv[i];
@@ -294,7 +294,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {

 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     const char split_delim = ',';
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;

     if (arg == "-s" || arg == "--seed") {
         CHECK_ARG
@@ -1375,7 +1375,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
 #endif

 void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;

     std::string sampler_type_chars;
     std::string sampler_type_names;
@@ -3116,7 +3116,7 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha

 void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
                                const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;

     fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
     fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER);


@@ -108,8 +108,7 @@ struct gpt_params {
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings

-    // // sampling parameters
-    struct llama_sampling_params sparams;
+    struct gpt_sampling_params sparams;

     std::string model = ""; // model path
     std::string model_draft = ""; // draft model for speculative decoding


@@ -1,8 +1,8 @@
 #include "sampling.h"

-#include <random>
+#include "common.h"

-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model) {
+struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_params & params, const struct llama_model * model) {
     struct llama_sampling_context * result = new llama_sampling_context();

     result->params = params;
@@ -58,7 +58,7 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
     return result;
 }

-std::string llama_sampling_print(const llama_sampling_params & params) {
+std::string llama_sampling_print(const gpt_sampling_params & params) {
     char result[1024];

     snprintf(result, sizeof(result),
@@ -72,7 +72,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
     return std::string(result);
 }

-std::string llama_sampling_order_print(const llama_sampling_params & params) {
+std::string llama_sampling_order_print(const gpt_sampling_params & params) {
     std::string result = "CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
@@ -176,7 +176,7 @@ static void sampler_queue(
                   size_t min_keep) {
     llama_sampling * smpl = ctx_sampling->smpl;

-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;

     const float temp = params.temp;
     const float dynatemp_range = params.dynatemp_range;
@@ -217,7 +217,7 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {
     llama_sampling * smpl = ctx_sampling->smpl;

-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;

     const float temp = params.temp;
     const int mirostat = params.mirostat;
@@ -308,7 +308,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
                   std::vector<float> * original_logits) {
     llama_sampling * smpl = ctx_sampling->smpl;

-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;

     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));


@@ -2,7 +2,6 @@

 #include "llama.h"

-#include <random>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -18,7 +17,7 @@ enum class llama_sampler_type : char {
 };

 // sampling parameters
-typedef struct llama_sampling_params {
+typedef struct gpt_sampling_params {
     int32_t n_prev = 64;  // number of previous tokens to remember
     int32_t n_probs = 0;  // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
@@ -60,13 +59,13 @@ typedef struct llama_sampling_params {
     std::vector<llama_token> penalty_prompt_tokens;
     bool use_penalty_prompt_tokens = false;
-} llama_sampling_params;
+} gpt_sampling_params;

 // general sampler context
 // TODO: move to llama.h
 struct llama_sampling_context {
     // parameters that will be used for sampling
-    llama_sampling_params params;
+    gpt_sampling_params params;

     // mirostat sampler state
     float mirostat_mu;
@@ -80,10 +79,8 @@ struct llama_sampling_context {
     size_t n_valid; // Number of correct top tokens with correct probabilities.
 };

-#include "common.h"
-
 // Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model);
+struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_params & params, const struct llama_model * model);

 void llama_sampling_free(struct llama_sampling_context * ctx);
@@ -102,10 +99,10 @@ llama_token llama_sampling_last(llama_sampling_context * ctx);
 std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);

 // Print sampling parameters into a string
-std::string llama_sampling_print(const llama_sampling_params & params);
+std::string llama_sampling_print(const gpt_sampling_params & params);

 // Print sampling order into a string
-std::string llama_sampling_order_print(const llama_sampling_params & params);
+std::string llama_sampling_order_print(const gpt_sampling_params & params);

 std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
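
For orientation, a minimal sketch (not part of the patch) of how a tool might drive the renamed common-layer API. It only uses the functions declared above; model loading is elided and `model` is assumed to be a valid pointer obtained elsewhere, while the overridden field values are purely illustrative.

    #include "sampling.h"

    #include <cstdio>

    // hypothetical helper, assuming `model` was loaded elsewhere
    static void run_sampling_demo(const struct llama_model * model) {
        gpt_sampling_params sparams;   // field initializers supply the defaults
        sparams.temp   = 0.7f;         // illustrative override
        sparams.n_prev = 128;          // illustrative override

        // the context is now created against a model instead of lazily
        struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, model);

        printf("%s\n", llama_sampling_print(sparams).c_str());
        printf("%s\n", llama_sampling_order_print(sparams).c_str());

        llama_sampling_free(ctx_sampling);
    }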


@@ -103,7 +103,6 @@ static void sigint_handler(int signo) {

 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sparams;
     g_params = &params;

     if (!gpt_params_parse(argc, argv, params)) {
@@ -111,6 +110,8 @@ int main(int argc, char ** argv) {
         return 1;
     }

+    auto & sparams = params.sparams;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("infill", "log"));
     LOG_TEE("Log start\n");


@@ -137,7 +137,7 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;

 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("main", "log"));


@@ -170,11 +170,13 @@ struct server_slot {
     std::string stopping_word;

     // sampling
-    llama_token sampled;
-    struct llama_sampling_params sparams;
-    llama_sampling_context * ctx_sampling = nullptr;
     json json_schema;

+    struct gpt_sampling_params sparams;
+
+    llama_token sampled;
+    llama_sampling_context * ctx_sampling = nullptr;
+
     int32_t ga_i = 0;   // group-attention state
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
@@ -893,8 +895,8 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot_params default_params;
         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-        llama_sampling_params default_sparams = params.sparams;
-        auto & data = task.data;
+        auto default_sparams = params.sparams;
+        const auto & data = task.data;

         if (data.count("__oaicompat") != 0) {
             slot.oaicompat = true;
@@ -933,7 +935,8 @@
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
             return false;
-        } else if (data.contains("json_schema") && !data.contains("grammar")) {
+        }
+        if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema = json_value(data, "json_schema", json::object());
                 slot.sparams.grammar = json_schema_to_grammar(schema);
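
The comment in launch_slot_with_task spells out the intended split: server-wide sampling defaults come from the parsed gpt_params, and each request may override individual fields from its JSON body. Inside that function the pattern looks roughly like the hedged fragment below; the request keys ("temperature", "n_prev") are illustrative and not taken from this patch.

    // server-wide defaults, copied per slot
    auto default_sparams = params.sparams;

    // per-request overrides fall back to the defaults when a key is absent
    slot.sparams.temp   = json_value(data, "temperature", default_sparams.temp);
    slot.sparams.n_prev = json_value(data, "n_prev",      default_sparams.n_prev);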


@@ -56,6 +56,7 @@ extern "C" {
     // struct llama_vocab; // TODO: add in the future
     struct llama_model;
     struct llama_context;
+    struct llama_sampling;

     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -355,8 +356,29 @@ extern "C" {
         void * kv_overrides; // pointer to vector containing overrides
     } llama_model_quantize_params;

-    // sampling types
-    struct llama_sampling;
+    // parameters for sampling the logits
+    typedef struct llama_sampling_params {
+        uint32_t seed;              // the seed used to initialize llama_sampling_context
+        int32_t  n_prev;            // number of previous tokens to remember
+        int32_t  n_probs;           // if greater than 0, output the probabilities of top n_probs tokens.
+        int32_t  min_keep;          // 0 = disabled, otherwise samplers should return at least min_keep tokens
+        int32_t  top_k;             // <= 0 to use vocab size
+        float    top_p;             // 1.0 = disabled
+        float    min_p;             // 0.0 = disabled
+        float    tfs_z;             // 1.0 = disabled
+        float    typical_p;         // 1.0 = disabled
+        float    temp;              // <= 0.0 to sample greedily, 0.0 to not output probabilities
+        float    dynatemp_range;    // 0.0 = disabled
+        float    dynatemp_exponent; // controls how entropy maps to temperature in dynamic temperature sampler
+        int32_t  penalty_last_n;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+        float    penalty_repeat;    // 1.0 = disabled
+        float    penalty_freq;      // 0.0 = disabled
+        float    penalty_present;   // 0.0 = disabled
+        int32_t  mirostat;          // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+        float    mirostat_tau;      // target entropy
+        float    mirostat_eta;      // learning rate
+        bool     penalize_nl;       // consider newlines as a repeatable token
+    } llama_sampling_params;

     // performance timing information
     struct llama_timings {
@@ -385,9 +407,10 @@ extern "C" {
     struct llama_lora_adapter;

     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params          llama_model_default_params(void);
     LLAMA_API struct llama_context_params        llama_context_default_params(void);
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
+    LLAMA_API struct llama_sampling_params       llama_sampling_default_params(void);

     // Initialize the llama + ggml backend
     // If numa is true, use NUMA optimizations
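
A minimal sketch (not part of the patch) of how a client of the new public C API might use the struct and the new default-params helper: start from llama_sampling_default_params() and override selected fields. The overridden values are illustrative.

    #include "llama.h"

    #include <stdio.h>

    int main(void) {
        // library defaults first, then per-application overrides
        struct llama_sampling_params sparams = llama_sampling_default_params();
        sparams.temp  = 0.7f; // illustrative override
        sparams.top_k = 50;   // illustrative override

        printf("seed=%u temp=%.2f top_k=%d top_p=%.2f mirostat=%d\n",
               sparams.seed, sparams.temp, sparams.top_k, sparams.top_p, sparams.mirostat);
        return 0;
    }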


@@ -16490,6 +16490,33 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
     return result;
 }

+struct llama_sampling_params llama_sampling_default_params() {
+    struct llama_sampling_params result = {
+        /*.seed              =*/ LLAMA_DEFAULT_SEED,
+        /*.n_prev            =*/ 64,
+        /*.n_probs           =*/ 0,
+        /*.min_keep          =*/ 0,
+        /*.top_k             =*/ 40,
+        /*.top_p             =*/ 0.95f,
+        /*.min_p             =*/ 0.05f,
+        /*.tfs_z             =*/ 1.00f,
+        /*.typical_p         =*/ 1.00f,
+        /*.temp              =*/ 0.80f,
+        /*.dynatemp_range    =*/ 0.00f,
+        /*.dynatemp_exponent =*/ 1.00f,
+        /*.penalty_last_n    =*/ 64,
+        /*.penalty_repeat    =*/ 1.00f,
+        /*.penalty_freq      =*/ 0.00f,
+        /*.penalty_present   =*/ 0.00f,
+        /*.mirostat          =*/ 0,
+        /*.mirostat_tau      =*/ 5.00f,
+        /*.mirostat_eta      =*/ 0.10f,
+        /*.penalize_nl       =*/ false,
+    };
+
+    return result;
+}
+
 size_t llama_max_devices(void) {
 #if defined(GGML_USE_RPC)
     return GGML_RPC_MAX_SERVERS;
@@ -16731,7 +16758,7 @@ struct llama_context * llama_new_context_with_model(
     ctx->logits_all = params.logits_all;

     // build worst-case graph for encoder if a model contains encoder
     ctx->is_encoding = llama_model_has_encoder(model);

     uint32_t kv_size = cparams.n_ctx;
     ggml_type type_k = params.type_k;