llama : introduce llama_sampling_params (wip)
ggml-ci
This commit is contained in: parent ae9d3f68e9, commit 6b7103cccd

9 changed files with 85 additions and 35 deletions

@@ -226,7 +226,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     bool invalid_param = false;
     std::string arg;
     const std::string arg_prefix = "--";
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
     for (int i = 1; i < argc; i++) {
         arg = argv[i];
@@ -294,7 +294,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
     const char split_delim = ',';
 
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
     if (arg == "-s" || arg == "--seed") {
         CHECK_ARG
@@ -1375,7 +1375,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
 #endif
 
 void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;
 
     std::string sampler_type_chars;
     std::string sampler_type_names;
@@ -3116,7 +3116,7 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
 
 void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx,
                                const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
-    const llama_sampling_params & sparams = params.sparams;
+    const auto & sparams = params.sparams;
 
     fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
     fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER);
@@ -108,8 +108,7 @@ struct gpt_params {
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
-    // // sampling parameters
-    struct llama_sampling_params sparams;
+    struct gpt_sampling_params sparams;
 
     std::string model = ""; // model path
     std::string model_draft = ""; // draft model for speculative decoding
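
For orientation (not part of the diff): because the gpt_params member keeps the name sparams, call sites can bind it through auto & and stay source-compatible across the rename to gpt_sampling_params. A minimal sketch, touching only fields that appear elsewhere in this commit:

    #include "common.h"

    int main() {
        gpt_params params;                 // example/CLI parameters from common.h
        auto & sparams = params.sparams;   // deduces to gpt_sampling_params after this change

        sparams.n_prev   = 128;            // argument handlers write through this alias
        sparams.min_keep = 1;

        return 0;
    }
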
@@ -1,8 +1,8 @@
 #include "sampling.h"
 
-#include <random>
+#include "common.h"
 
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model) {
+struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_params & params, const struct llama_model * model) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
     result->params = params;
@@ -58,7 +58,7 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
     return result;
 }
 
-std::string llama_sampling_print(const llama_sampling_params & params) {
+std::string llama_sampling_print(const gpt_sampling_params & params) {
     char result[1024];
 
     snprintf(result, sizeof(result),
@@ -72,7 +72,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
     return std::string(result);
 }
 
-std::string llama_sampling_order_print(const llama_sampling_params & params) {
+std::string llama_sampling_order_print(const gpt_sampling_params & params) {
     std::string result = "CFG -> Penalties ";
     if (params.mirostat == 0) {
         for (auto sampler_type : params.samplers_sequence) {
@@ -176,7 +176,7 @@ static void sampler_queue(
                    size_t min_keep) {
     llama_sampling * smpl = ctx_sampling->smpl;
 
-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;
 
     const float temp = params.temp;
     const float dynatemp_range = params.dynatemp_range;
@@ -217,7 +217,7 @@ static llama_token llama_sampling_sample_impl(
                    bool is_resampling) {
     llama_sampling * smpl = ctx_sampling->smpl;
 
-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;
 
     const float temp = params.temp;
     const int mirostat = params.mirostat;
@@ -308,7 +308,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
                    std::vector<float> * original_logits) {
     llama_sampling * smpl = ctx_sampling->smpl;
 
-    const llama_sampling_params & params = ctx_sampling->params;
+    const gpt_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
 
@@ -2,7 +2,6 @@
 
 #include "llama.h"
 
-#include <random>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -18,7 +17,7 @@ enum class llama_sampler_type : char {
 };
 
 // sampling parameters
-typedef struct llama_sampling_params {
+typedef struct gpt_sampling_params {
     int32_t n_prev   = 64; // number of previous tokens to remember
     int32_t n_probs  = 0;  // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t min_keep = 0;  // 0 = disabled, otherwise samplers should return at least min_keep tokens
@@ -60,13 +59,13 @@ typedef struct llama_sampling_params {
 
     std::vector<llama_token> penalty_prompt_tokens;
     bool use_penalty_prompt_tokens = false;
-} llama_sampling_params;
+} gpt_sampling_params;
 
 // general sampler context
 // TODO: move to llama.h
 struct llama_sampling_context {
     // parameters that will be used for sampling
-    llama_sampling_params params;
+    gpt_sampling_params params;
 
     // mirostat sampler state
     float mirostat_mu;
@@ -80,10 +79,8 @@ struct llama_sampling_context {
     size_t n_valid; // Number of correct top tokens with correct probabilities.
 };
 
-#include "common.h"
-
 // Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model);
+struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_params & params, const struct llama_model * model);
 
 void llama_sampling_free(struct llama_sampling_context * ctx);
 
@@ -102,10 +99,10 @@ llama_token llama_sampling_last(llama_sampling_context * ctx);
 std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
 
 // Print sampling parameters into a string
-std::string llama_sampling_print(const llama_sampling_params & params);
+std::string llama_sampling_print(const gpt_sampling_params & params);
 
 // Print sampling order into a string
-std::string llama_sampling_order_print(const llama_sampling_params & params);
+std::string llama_sampling_order_print(const gpt_sampling_params & params);
 
 std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);
 
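
Taken together, the common sampling API above is unchanged apart from the struct rename. A minimal usage sketch (not part of the diff; it assumes a llama_model has already been loaded and touches only fields and functions that appear in this commit):

    #include "sampling.h"

    // sketch: set up and tear down a sampling context around the renamed struct
    static void sampling_demo(const struct llama_model * model) {
        gpt_sampling_params sparams;   // defaults come from the in-class initializers above
        sparams.n_prev = 128;
        sparams.temp   = 0.7f;

        struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, model);

        // ... use the llama_sampling_* helpers to pick tokens during generation ...
        // e.g. print the current settings: llama_sampling_print(sparams)

        llama_sampling_free(ctx_sampling);
    }
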
@@ -103,7 +103,6 @@ static void sigint_handler(int signo) {
 
 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sparams;
     g_params = &params;
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -111,6 +110,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    auto & sparams = params.sparams;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("infill", "log"));
     LOG_TEE("Log start\n");
@@ -137,7 +137,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    llama_sampling_params & sparams = params.sparams;
+    auto & sparams = params.sparams;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("main", "log"));
@@ -170,11 +170,13 @@ struct server_slot {
     std::string stopping_word;
 
     // sampling
-    llama_token sampled;
-    struct llama_sampling_params sparams;
-    llama_sampling_context * ctx_sampling = nullptr;
     json json_schema;
+
+    struct gpt_sampling_params sparams;
+
+    llama_token sampled;
+    llama_sampling_context * ctx_sampling = nullptr;
 
     int32_t ga_i = 0;   // group-attention state
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
@@ -893,8 +895,8 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         slot_params default_params;
         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-        llama_sampling_params default_sparams = params.sparams;
-        auto & data = task.data;
+        auto default_sparams = params.sparams;
+        const auto & data = task.data;
 
         if (data.count("__oaicompat") != 0) {
             slot.oaicompat = true;
@@ -933,7 +935,8 @@ struct server_context {
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
             return false;
-        } else if (data.contains("json_schema") && !data.contains("grammar")) {
+        }
+        if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema = json_value(data, "json_schema", json::object());
                 slot.sparams.grammar = json_schema_to_grammar(schema);
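
In the last hunk above, json_schema and grammar stay mutually exclusive, and a supplied schema is converted into a grammar for the slot. A rough sketch of that conversion path (not part of the diff; grammar_from_request is a made-up name, and it assumes the common json-schema-to-grammar helper keeps its current signature):

    #include "json-schema-to-grammar.h"

    #include <string>

    using json = nlohmann::ordered_json;

    // returns the grammar derived from a request's "json_schema" field,
    // or an empty string when no usable schema was supplied
    static std::string grammar_from_request(const json & data) {
        if (!data.contains("json_schema") || data.contains("grammar")) {
            return "";
        }
        const json schema = data.value("json_schema", json::object());
        return json_schema_to_grammar(schema);
    }
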
@@ -56,6 +56,7 @@ extern "C" {
     // struct llama_vocab; // TODO: add in the future
     struct llama_model;
     struct llama_context;
+    struct llama_sampling;
 
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
@@ -355,8 +356,29 @@ extern "C" {
         void * kv_overrides; // pointer to vector containing overrides
     } llama_model_quantize_params;
 
-    // sampling types
-    struct llama_sampling;
+    // parameters for sampling the logits
+    typedef struct llama_sampling_params {
+        uint32_t seed;              // the seed used to initialize llama_sampling_context
+        int32_t  n_prev;            // number of previous tokens to remember
+        int32_t  n_probs;           // if greater than 0, output the probabilities of top n_probs tokens.
+        int32_t  min_keep;          // 0 = disabled, otherwise samplers should return at least min_keep tokens
+        int32_t  top_k;             // <= 0 to use vocab size
+        float    top_p;             // 1.0 = disabled
+        float    min_p;             // 0.0 = disabled
+        float    tfs_z;             // 1.0 = disabled
+        float    typical_p;         // 1.0 = disabled
+        float    temp;              // <= 0.0 to sample greedily, 0.0 to not output probabilities
+        float    dynatemp_range;    // 0.0 = disabled
+        float    dynatemp_exponent; // controls how entropy maps to temperature in dynamic temperature sampler
+        int32_t  penalty_last_n;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+        float    penalty_repeat;    // 1.0 = disabled
+        float    penalty_freq;      // 0.0 = disabled
+        float    penalty_present;   // 0.0 = disabled
+        int32_t  mirostat;          // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+        float    mirostat_tau;      // target entropy
+        float    mirostat_eta;      // learning rate
+        bool     penalize_nl;       // consider newlines as a repeatable token
+    } llama_sampling_params;
 
     // performance timing information
     struct llama_timings {
@@ -385,9 +407,10 @@ extern "C" {
     struct llama_lora_adapter;
 
     // Helpers for getting default parameters
-    LLAMA_API struct llama_model_params llama_model_default_params(void);
-    LLAMA_API struct llama_context_params llama_context_default_params(void);
+    LLAMA_API struct llama_model_params          llama_model_default_params(void);
+    LLAMA_API struct llama_context_params        llama_context_default_params(void);
     LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
+    LLAMA_API struct llama_sampling_params       llama_sampling_default_params(void);
 
     // Initialize the llama + ggml backend
     // If numa is true, use NUMA optimizations
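
With llama_sampling_params and llama_sampling_default_params() in the public header, sampling parameters can be filled without pulling in the common library. A minimal sketch (not part of the diff) that only touches fields declared in the hunk above:

    #include "llama.h"

    int main() {
        // start from the library defaults, then override selectively
        struct llama_sampling_params sparams = llama_sampling_default_params();

        sparams.seed     = 1234;   // fixed seed for reproducible runs
        sparams.temp     = 0.7f;
        sparams.mirostat = 0;      // 0 = disabled, per the field comment above

        // the struct llama_sampling forward-declared above is presumably
        // initialized from these parameters by later commits in this WIP series
        return 0;
    }
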
@@ -16490,6 +16490,33 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
     return result;
 }
 
+struct llama_sampling_params llama_sampling_default_params() {
+    struct llama_sampling_params result = {
+        /*.seed              =*/ LLAMA_DEFAULT_SEED,
+        /*.n_prev            =*/ 64,
+        /*.n_probs           =*/ 0,
+        /*.min_keep          =*/ 0,
+        /*.top_k             =*/ 40,
+        /*.top_p             =*/ 0.95f,
+        /*.min_p             =*/ 0.05f,
+        /*.tfs_z             =*/ 1.00f,
+        /*.typical_p         =*/ 1.00f,
+        /*.temp              =*/ 0.80f,
+        /*.dynatemp_range    =*/ 0.00f,
+        /*.dynatemp_exponent =*/ 1.00f,
+        /*.penalty_last_n    =*/ 64,
+        /*.penalty_repeat    =*/ 1.00f,
+        /*.penalty_freq      =*/ 0.00f,
+        /*.penalty_present   =*/ 0.00f,
+        /*.mirostat          =*/ 0,
+        /*.mirostat_tau      =*/ 5.00f,
+        /*.mirostat_eta      =*/ 0.10f,
+        /*.penalize_nl       =*/ false,
+    };
+
+    return result;
+}
+
 size_t llama_max_devices(void) {
 #if defined(GGML_USE_RPC)
     return GGML_RPC_MAX_SERVERS;
@@ -16731,7 +16758,7 @@ struct llama_context * llama_new_context_with_model(
     ctx->logits_all = params.logits_all;
 
     // build worst-case graph for encoder if a model contains encoder
-    ctx->is_encoding = llama_model_has_encoder(model);
+    ctx->is_encoding = llama_model_has_encoder(model);
 
     uint32_t kv_size = cparams.n_ctx;
     ggml_type type_k = params.type_k;