llama : introduce llama_sampling_params (wip)
ggml-ci
parent ae9d3f68e9
commit 6b7103cccd

9 changed files with 85 additions and 35 deletions
@@ -226,7 +226,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     bool invalid_param = false;
     std::string arg;
     const std::string arg_prefix = "--";
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
     for (int i = 1; i < argc; i++) {
         arg = argv[i];
@@ -294,7 +294,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     const char split_delim = ',';
 
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
     if (arg == "-s" || arg == "--seed") {
         CHECK_ARG
@@ -1375,7 +1375,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
 #endif
 
 void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;
 
     std::string sampler_type_chars;
     std::string sampler_type_names;
@@ -3116,7 +3116,7 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
 
 void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
                                const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;
 
     fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
     fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER);
@@ -108,8 +108,7 @@ struct gpt_params {
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
-    // // sampling parameters
-    struct llama_sampling_params sparams;
+    struct gpt_sampling_params sparams;
 
     std::string model = ""; // model path
     std::string model_draft = ""; // draft model for speculative decoding
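For orientation, a minimal sketch (not part of this commit) of how caller code is expected to reach the renamed struct: gpt_params now embeds a gpt_sampling_params, and the call sites above simply bind it with `auto &` instead of spelling the old llama_sampling_params type. Field names and defaults are the ones declared in the sampling.h hunks further down.

    #include "common.h"   // gpt_params, which now embeds gpt_sampling_params

    // Hypothetical illustration of the pattern used in the common.cpp hunks above.
    static void tweak_sampling_defaults(gpt_params & params) {
        auto & sparams = params.sparams;  // binds to the embedded gpt_sampling_params
        sparams.n_prev   = 128;           // remember more previous tokens (default 64)
        sparams.min_keep = 1;             // keep at least one candidate (default 0 = disabled)
    }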
@@ -1,8 +1,8 @@
 #include "sampling.h"
 
-#include <random>
+#include "common.h"
 
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model) {
+struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_params & params, const struct llama_model * model) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
     result->params = params;
@@ -58,7 +58,7 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
     return result;
 }
 
-std::string llama_sampling_print(const llama_sampling_params & params) {
+std::string llama_sampling_print(const gpt_sampling_params & params) {
     char result[1024];
 
     snprintf(result, sizeof(result),
@@ -72,7 +72,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
     return std::string(result);
 }
 
-std::string llama_sampling_order_print(const llama_sampling_params & params) {
+std::string llama_sampling_order_print(const gpt_sampling_params & params) {
     std::string result = "CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
@@ -176,7 +176,7 @@ static void sampler_queue(
                   size_t min_keep) {
     llama_sampling * smpl = ctx_sampling->smpl;
 
-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;
 
     const float temp = params.temp;
     const float dynatemp_range = params.dynatemp_range;
@@ -217,7 +217,7 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {
     llama_sampling * smpl = ctx_sampling->smpl;
 
-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;
 
     const float temp = params.temp;
     const int mirostat = params.mirostat;
@@ -308,7 +308,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
                   std::vector<float> * original_logits) {
     llama_sampling * smpl = ctx_sampling->smpl;
 
-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
 
@@ -2,7 +2,6 @@
 
 #include "llama.h"
 
-#include <random>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -18,7 +17,7 @@ enum class llama_sampler_type : char {
 };
 
 // sampling parameters
-typedef struct llama_sampling_params {
+typedef struct gpt_sampling_params {
     int32_t n_prev = 64; // number of previous tokens to remember
     int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
@@ -60,13 +59,13 @@ typedef struct llama_sampling_params {
 
     std::vector<llama_token> penalty_prompt_tokens;
     bool use_penalty_prompt_tokens = false;
-} llama_sampling_params;
+} gpt_sampling_params;
 
 // general sampler context
 // TODO: move to llama.h
 struct llama_sampling_context {
     // parameters that will be used for sampling
-    llama_sampling_params params;
+    gpt_sampling_params params;
 
     // mirostat sampler state
     float mirostat_mu;
@@ -80,10 +79,8 @@ struct llama_sampling_context {
     size_t n_valid; // Number of correct top tokens with correct probabilities.
 };
 
-#include "common.h"
-
 // Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model);
+struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_params & params, const struct llama_model * model);
 
 void llama_sampling_free(struct llama_sampling_context * ctx);
 
@@ -102,10 +99,10 @@ llama_token llama_sampling_last(llama_sampling_context * ctx);
 std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
 
 // Print sampling parameters into a string
-std::string llama_sampling_print(const llama_sampling_params & params);
+std::string llama_sampling_print(const gpt_sampling_params & params);
 
 // Print sampling order into a string
-std::string llama_sampling_order_print(const llama_sampling_params & params);
+std::string llama_sampling_order_print(const gpt_sampling_params & params);
 
 std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
 
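To make the rename concrete, here is a small, hedged usage sketch (not code from this commit): it relies only on gpt_sampling_params, llama_sampling_init and llama_sampling_free as declared above, and assumes a loaded llama_model pointer is available from elsewhere.

    #include "sampling.h"   // common helper API shown above

    // Sketch only: `model` is assumed to have been loaded elsewhere.
    static void sampling_example(const llama_model * model) {
        gpt_sampling_params sparams;   // defaults as declared in sampling.h (n_prev = 64, ...)
        sparams.temp    = 0.7f;        // cf. params.temp in the sampling.cpp hunks
        sparams.n_probs = 5;           // report probabilities of the top 5 tokens

        llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, model);
        // ... drive sampling with the llama_sampling_* helpers ...
        llama_sampling_free(ctx_sampling);
    }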
@@ -103,7 +103,6 @@ static void sigint_handler(int signo) {
 
 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sparams;
     g_params = &params;
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -111,6 +110,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    auto & sparams = params.sparams;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("infill", "log"));
     LOG_TEE("Log start\n");
@@ -137,7 +137,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("main", "log"));
@@ -170,11 +170,13 @@ struct server_slot {
     std::string stopping_word;
 
     // sampling
-    llama_token sampled;
-    struct llama_sampling_params sparams;
-    llama_sampling_context * ctx_sampling = nullptr;
     json json_schema;
 
+    struct gpt_sampling_params sparams;
+
+    llama_token sampled;
+    llama_sampling_context * ctx_sampling = nullptr;
+
     int32_t ga_i = 0; // group-attention state
     int32_t ga_n = 1; // group-attention factor
     int32_t ga_w = 512; // group-attention width
@@ -893,8 +895,8 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot_params default_params;
         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-        llama_sampling_params default_sparams = params.sparams;
-        auto & data = task.data;
+        auto default_sparams = params.sparams;
+        const auto & data = task.data;
 
         if (data.count("__oaicompat") != 0) {
             slot.oaicompat = true;
@@ -933,7 +935,8 @@ struct server_context {
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
             return false;
-        } else if (data.contains("json_schema") && !data.contains("grammar")) {
+        }
+        if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema = json_value(data, "json_schema", json::object());
                 slot.sparams.grammar = json_schema_to_grammar(schema);
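The reordered server_slot members and the launch_slot_with_task hunk encode the pattern that per-slot sampling settings start from the server-wide params.sparams and are then overridden per request. A hedged sketch of that flow (the "temperature" key and the include path are assumptions modeled on how the surrounding server code reads other request fields):

    #include "sampling.h"            // gpt_sampling_params (see hunks above)
    #include <nlohmann/json.hpp>     // JSON type used by the server (vendored path may differ)
    using json = nlohmann::json;

    // Hypothetical helper mirroring the server's "defaults, then per-request override" pattern.
    static gpt_sampling_params slot_sparams_from_request(const gpt_sampling_params & default_sparams, const json & data) {
        gpt_sampling_params sparams = default_sparams;                    // start from server-wide defaults
        sparams.temp = data.value("temperature", default_sparams.temp);  // override only if the request sets it
        return sparams;
    }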
@@ -56,6 +56,7 @@ extern "C" {
     // struct llama_vocab; // TODO: add in the future
     struct llama_model;
     struct llama_context;
+    struct llama_sampling;
 
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -355,8 +356,29 @@ extern "C" {
         void * kv_overrides; // pointer to vector containing overrides
     } llama_model_quantize_params;
 
-    // sampling types
-    struct llama_sampling;
+    // parameters for sampling the logits
+    typedef struct llama_sampling_params {
+        uint32_t seed;              // the seed used to initialize llama_sampling_context
+        int32_t  n_prev;            // number of previous tokens to remember
+        int32_t  n_probs;           // if greater than 0, output the probabilities of top n_probs tokens.
+        int32_t  min_keep;          // 0 = disabled, otherwise samplers should return at least min_keep tokens
+        int32_t  top_k;             // <= 0 to use vocab size
+        float    top_p;             // 1.0 = disabled
+        float    min_p;             // 0.0 = disabled
+        float    tfs_z;             // 1.0 = disabled
+        float    typical_p;         // 1.0 = disabled
+        float    temp;              // <= 0.0 to sample greedily, 0.0 to not output probabilities
+        float    dynatemp_range;    // 0.0 = disabled
+        float    dynatemp_exponent; // controls how entropy maps to temperature in dynamic temperature sampler
+        int32_t  penalty_last_n;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+        float    penalty_repeat;    // 1.0 = disabled
+        float    penalty_freq;      // 0.0 = disabled
+        float    penalty_present;   // 0.0 = disabled
+        int32_t  mirostat;          // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+        float    mirostat_tau;      // target entropy
+        float    mirostat_eta;      // learning rate
+        bool     penalize_nl;       // consider newlines as a repeatable token
+    } llama_sampling_params;
 
     // performance timing information
     struct llama_timings {
@@ -385,9 +407,10 @@ extern "C" {
     struct llama_lora_adapter;
 
     // Helpers for getting default parameters
     LLAMA_API struct llama_model_params llama_model_default_params(void);
     LLAMA_API struct llama_context_params llama_context_default_params(void);
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
+    LLAMA_API struct llama_sampling_params llama_sampling_default_params(void);
 
     // Initialize the llama + ggml backend
     // If numa is true, use NUMA optimizations
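Since llama_sampling_params now lives in the public C API, a caller is presumably meant to start from llama_sampling_default_params() (declared above and implemented in the llama.cpp hunk below) and adjust individual fields; the commit is marked WIP, so nothing in this change consumes the struct yet. A minimal sketch under that assumption:

    #include "llama.h"

    // Sketch: obtain the library defaults, then adjust a few fields.
    static struct llama_sampling_params my_sampling_params(void) {
        struct llama_sampling_params sparams = llama_sampling_default_params();
        sparams.seed  = 1234;   // fixed seed instead of LLAMA_DEFAULT_SEED
        sparams.temp  = 0.2f;   // lower than the 0.80f default
        sparams.top_k = 20;     // tighter than the default of 40
        return sparams;         // not yet wired into a sampler in this commit (WIP)
    }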
@@ -16490,6 +16490,33 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
     return result;
 }
 
+struct llama_sampling_params llama_sampling_default_params() {
+    struct llama_sampling_params result = {
+        /*.seed              =*/ LLAMA_DEFAULT_SEED,
+        /*.n_prev            =*/ 64,
+        /*.n_probs           =*/ 0,
+        /*.min_keep          =*/ 0,
+        /*.top_k             =*/ 40,
+        /*.top_p             =*/ 0.95f,
+        /*.min_p             =*/ 0.05f,
+        /*.tfs_z             =*/ 1.00f,
+        /*.typical_p         =*/ 1.00f,
+        /*.temp              =*/ 0.80f,
+        /*.dynatemp_range    =*/ 0.00f,
+        /*.dynatemp_exponent =*/ 1.00f,
+        /*.penalty_last_n    =*/ 64,
+        /*.penalty_repeat    =*/ 1.00f,
+        /*.penalty_freq      =*/ 0.00f,
+        /*.penalty_present   =*/ 0.00f,
+        /*.mirostat          =*/ 0,
+        /*.mirostat_tau      =*/ 5.00f,
+        /*.mirostat_eta      =*/ 0.10f,
+        /*.penalize_nl       =*/ false,
+    };
+
+    return result;
+}
+
 size_t llama_max_devices(void) {
 #if defined(GGML_USE_RPC)
     return GGML_RPC_MAX_SERVERS;
@@ -16731,7 +16758,7 @@ struct llama_context * llama_new_context_with_model(
     ctx->logits_all = params.logits_all;
 
     // build worst-case graph for encoder if a model contains encoder
     ctx->is_encoding = llama_model_has_encoder(model);
 
     uint32_t kv_size = cparams.n_ctx;
     ggml_type type_k = params.type_k;