integrated mirostat as a launch parameter, works on all models

parent 851f55325a
commit 8a964e76c8

3 changed files with 102 additions and 26 deletions
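With this commit, mirostat is configured once at launch instead of per request. A launch along the lines of `python koboldcpp.py --usemirostat 2 5.0 0.1 [model file]` (the model argument shown here is illustrative) runs every generation through mirostat v2 with tau = 5.0 and eta = 0.1, regardless of which model backend is loaded.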
expose.h (2 changed lines)

@@ -33,7 +33,7 @@ struct generation_inputs
     const float tfs;
     const float rep_pen;
     const int rep_pen_range;
-    const int mirostat;
+    const int mirostat = 0;
     const float mirostat_eta;
     const float mirostat_tau;
     const char * stop_sequence[stop_token_max];
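koboldcpp.py talks to this struct through ctypes, so the Python-side generation_inputs must declare matching fields in the same order; the `= 0` default member initializer also makes mirostat default to disabled when the struct is default-constructed on the C++ side. The binding update itself is not shown in this diff; a minimal sketch of what the relevant slice presumably looks like, using only the fields from the excerpt above (the stop_token_max value is assumed):

    import ctypes

    # Minimal sketch, not the actual koboldcpp binding: just the fields from
    # the expose.h excerpt above, in the same order, with matching C types.
    stop_token_max = 10  # assumed value; defined elsewhere in the real sources

    class generation_inputs_sketch(ctypes.Structure):
        _fields_ = [("tfs", ctypes.c_float),
                    ("rep_pen", ctypes.c_float),
                    ("rep_pen_range", ctypes.c_int),
                    ("mirostat", ctypes.c_int),
                    ("mirostat_eta", ctypes.c_float),
                    ("mirostat_tau", ctypes.c_float),
                    ("stop_sequence", ctypes.c_char_p * stop_token_max)]

    inputs = generation_inputs_sketch()
    inputs.mirostat, inputs.mirostat_tau, inputs.mirostat_eta = 2, 5.0, 0.1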
@@ -95,7 +95,62 @@ llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
     return result;
 }

+llama_token sample_token_mirostat(int n_vocab, llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int m, float * mu)
+{
+    float N = float(n_vocab);
+    llama_sample_softmax(nullptr, candidates);
+
+    // Estimate s_hat using the most probable m tokens
+    float s_hat = 0.0;
+    float sum_ti_bi = 0.0;
+    float sum_ti_sq = 0.0;
+    for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
+        float t_i = logf(float(i + 2) / float(i + 1));
+        float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
+        sum_ti_bi += t_i * b_i;
+        sum_ti_sq += t_i * t_i;
+    }
+    s_hat = sum_ti_bi / sum_ti_sq;
+
+    // Compute k from the estimated s_hat and target surprise value
+    float epsilon_hat = s_hat - 1;
+    float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(N, -epsilon_hat)), 1 / s_hat);
+
+    // Sample the next word X using top-k sampling
+    llama_sample_top_k(nullptr, candidates, int(k));
+    llama_token X = sample_token(candidates, rng);
+
+    // Compute error as the difference between observed surprise and target surprise value
+    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+        return candidate.id == X;
+    }));
+    float observed_surprise = -log2f(candidates->data[X_idx].p);
+    float e = observed_surprise - tau;
+
+    // Update mu using the learning rate and error
+    *mu = *mu - eta * e;
+
+    return X;
+}
+
+llama_token sample_token_mirostat_v2(llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float * mu)
+{
+    llama_sample_softmax(nullptr, candidates);
+
+    // Truncate the words with surprise values greater than mu
+    candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+        return -log2f(candidate.p) > *mu;
+    }));
+
+    // Normalize the probabilities of the remaining words
+    llama_sample_softmax(nullptr, candidates);
+
+    // Sample the next word X from the remaining words
+    llama_token X = sample_token(candidates, rng);
+
+    // Compute error as the difference between observed surprise and target surprise value
+    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
+        return candidate.id == X;
+    }));
+    float observed_surprise = -log2f(candidates->data[X_idx].p);
+    float e = observed_surprise - tau;
+
+    // Update mu using the learning rate and error
+    *mu = *mu - eta * e;
+
+    return X;
+}
+
-int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng)
+int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
+                 int mirostat, float mirostat_tau, float mirostat_eta)
 {
     int id = 0;
     std::vector<llama_token_data> candidates;
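For readers new to the algorithm: mirostat tries to hold each token's surprise, -log2 p(token), near a target tau by adapting a cutoff mu with learning rate eta. v1 (above) first estimates the Zipf exponent s_hat from the top m token probabilities and converts mu into a top-k cutoff k = (epsilon_hat * 2^mu / (1 - N^-epsilon_hat))^(1/s_hat); v2 simply drops candidates whose surprise exceeds mu. Either way, after sampling, mu is nudged by eta times the surprise error. A toy, self-contained sketch of the v2 feedback loop over a fixed distribution (illustrative only, not the C++ above):

    import math
    import random

    def mirostat_v2_step(probs, tau, eta, mu):
        # probs: candidate probabilities sorted descending, like the
        # softmaxed candidate list in the C++ code.
        # Keep only candidates whose surprise -log2(p) does not exceed mu;
        # fall back to the top candidate if the cut would empty the list.
        kept = [p for p in probs if -math.log2(p) <= mu] or probs[:1]
        total = sum(kept)
        weights = [p / total for p in kept]               # renormalize
        idx = random.choices(range(len(kept)), weights=weights)[0]
        observed_surprise = -math.log2(weights[idx])
        mu -= eta * (observed_surprise - tau)             # nudge mu toward tau
        return idx, mu

    mu = 2.0 * 5.0                                        # mu starts at 2 * tau
    for _ in range(5):
        idx, mu = mirostat_v2_step([0.5, 0.3, 0.15, 0.05], tau=5.0, eta=0.1, mu=mu)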
@@ -116,10 +171,28 @@ int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range
     // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
     // last_n_repeat, alpha_frequency, alpha_presence);

-    if (temp <= 0) {
+    if (temp <= 0)
+    {
         // Greedy sampling
         id = llama_sample_token_greedy(nullptr, &candidates_p);
-    } else {
+    }
+    else
+    {
+        if (mirostat == 1)
+        {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            const int mirostat_m = 100;
+            llama_sample_temperature(nullptr, &candidates_p, temp);
+            id = sample_token_mirostat(n_vocab, &candidates_p, rng, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
+        }
+        else if (mirostat == 2)
+        {
+            static float mirostat_mu = 2.0f * mirostat_tau;
+            llama_sample_temperature(nullptr, &candidates_p, temp);
+            id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
+        }
+        else
+        {
         // Temperature sampling
         llama_sample_top_k(nullptr, &candidates_p, top_k);
         llama_sample_tail_free(nullptr, &candidates_p, tfs);
@@ -128,6 +201,7 @@ int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range
         llama_sample_temperature(nullptr, &candidates_p, temp);
         id = sample_token(&candidates_p, rng);
     }
+    }

     return id;
 }
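Two details in the dispatch above are worth noting. First, a nonzero mirostat setting bypasses the top-k/tail-free/typical/top-p chain entirely, but temperature is still applied before the mirostat sampler runs. Second, mirostat_mu is a function-local static, so it is initialized from 2.0f * mirostat_tau only on the very first call and then carries its adapted value across all subsequent calls for the lifetime of the process; that is harmless here because tau comes from a launch flag and never changes mid-session.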
@@ -647,7 +721,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                 logits[29961] = 0;
             }

-            id = SampleLogits(logits, nctx, n_vocab, last_n_size, repeat_penalty, top_k, top_p, typical_p, tfs_z, temp, rng);
+            id = SampleLogits(logits, nctx, n_vocab, last_n_size, repeat_penalty,
+                              top_k, top_p, typical_p, tfs_z, temp, rng,
+                              params.mirostat, params.mirostat_tau, params.mirostat_eta);

         }
         else
@@ -667,7 +743,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                 //gpt2 uses negative logits, so we cant zero it
             }

-            id = SampleLogits(logits.data(), nctx, n_vocab, last_n_size, repeat_penalty, top_k, top_p, typical_p, tfs_z, temp, rng);
+            id = SampleLogits(logits.data(), nctx, n_vocab, last_n_size, repeat_penalty,
+                              top_k, top_p, typical_p, tfs_z, temp, rng,
+                              params.mirostat, params.mirostat_tau, params.mirostat_eta);
         }

         last_n_tokens.erase(last_n_tokens.begin());
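The two SampleLogits call sites above — the llama path first, then the path for the other GGML model types (note the GPT-2 comment) — now forward the same params.mirostat, params.mirostat_tau, and params.mirostat_eta values. Routing everything through this one entry point is what lets the commit claim mirostat "works on all models".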
koboldcpp.py (22 changed lines)
@@ -157,7 +157,7 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret

-def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,mirostat=0,mirostat_lr=0.1,mirostat_ent=5.0,seed=-1,stop_sequence=[]):
+def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=100,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[]):
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
     inputs.prompt = prompt.encode("UTF-8")
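Note that mirostat, mirostat_lr, and mirostat_ent disappear from generate()'s keyword arguments entirely; the next hunk shows the launch-time values taking over inside the function body.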
@@ -170,9 +170,12 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
     inputs.tfs = tfs
     inputs.rep_pen = rep_pen
     inputs.rep_pen_range = rep_pen_range
-    inputs.mirostat = mirostat
-    inputs.mirostat_eta = mirostat_lr
-    inputs.mirostat_tau = mirostat_ent
+    if args.usemirostat and args.usemirostat[0]>0:
+        inputs.mirostat = int(args.usemirostat[0])
+        inputs.mirostat_tau = float(args.usemirostat[1])
+        inputs.mirostat_eta = float(args.usemirostat[2])
+    else:
+        inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
     inputs.seed = seed
     for n in range(0,stop_token_max):
         if not stop_sequence or n >= len(stop_sequence):
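One subtlety in the branch above: the flag is declared with type=float and nargs=3 (see the argparse hunk further down), so even the mirostat type arrives as a float and has to be cast back with int(). A quick self-contained illustration:

    import argparse

    # The new flag parses as three floats, so the mirostat type needs int().
    parser = argparse.ArgumentParser()
    parser.add_argument("--usemirostat", type=float, nargs=3)
    args = parser.parse_args(["--usemirostat", "2", "5.0", "0.1"])
    print(args.usemirostat)          # [2.0, 5.0, 0.1]
    print(int(args.usemirostat[0]))  # 2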
@@ -317,9 +320,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             tfs=genparams.get('tfs', 1.0),
             rep_pen=genparams.get('rep_pen', 1.1),
             rep_pen_range=genparams.get('rep_pen_range', 128),
-            mirostat=genparams.get('mirostat', 0),
-            mirostat_lr=genparams.get('mirostat_lr', 0.1),
-            mirostat_ent=genparams.get('mirostat_ent', 5.0),
             seed=-1,
             stop_sequence=genparams.get('stop_sequence', [])
             )
@@ -336,9 +336,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
             tfs=genparams.get('tfs', 1.0),
             rep_pen=genparams.get('rep_pen', 1.1),
             rep_pen_range=genparams.get('rep_pen_range', 128),
-            mirostat=genparams.get('mirostat', 0),
-            mirostat_lr=genparams.get('mirostat_lr', 0.1),
-            mirostat_ent=genparams.get('mirostat_ent', 5.0),
             seed=-1,
             stop_sequence=genparams.get('stop_sequence', [])
             )
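With these two removals, any mirostat, mirostat_lr, or mirostat_ent values supplied in a request's genparams are silently ignored by both endpoints; the launch flag is now the only way to enable or tune mirostat.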
@@ -620,14 +617,15 @@ if __name__ == '__main__':
     physical_core_limit = int(os.cpu_count()/2)
     default_threads = (physical_core_limit if physical_core_limit<=3 else max(3,physical_core_limit-1))
     parser.add_argument("--threads", help="Use a custom number of threads if specified. Otherwise, uses an amount based on CPU cores", type=int, default=default_threads)
-    parser.add_argument("--blasthreads", help="Use a different number of threads during BLAS if specified. Otherwise, has the same value as --threads", type=int, default=0)
+    parser.add_argument("--blasthreads", help="Use a different number of threads during BLAS if specified. Otherwise, has the same value as --threads",metavar=('[threads]'), type=int, default=0)
     parser.add_argument("--psutil_set_threads", help="Experimental flag. If set, uses psutils to determine thread count based on physical cores.", action='store_true')
     parser.add_argument("--highpriority", help="Experimental flag. If set, increases the process CPU priority, potentially speeding up generation. Use caution.", action='store_true')
     parser.add_argument("--blasbatchsize", help="Sets the batch size used in BLAS processing (default 512)", type=int,choices=[32,64,128,256,512,1024], default=512)
     parser.add_argument("--stream", help="Uses pseudo streaming when generating tokens. Only for the Kobold Lite UI.", action='store_true')
     parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
     parser.add_argument("--unbantokens", help="Normally, KoboldAI prevents certain tokens such as EOS and Square Brackets. This flag unbans them.", action='store_true')
-    parser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).", type=int, default=0)
+    parser.add_argument("--usemirostat", help="Experimental! Replaces your samplers with mirostat. Takes 3 params = [type(0/1/2), tau(5.0), eta(0.1)].",metavar=('[type]', '[tau]', '[eta]'), type=float, nargs=3)
+    parser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).",metavar=('[version]'), type=int, default=0)
     parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
     parser.add_argument("--usemlock", help="For Apple Systems. Force system to keep model in RAM rather than swapping or compressing", action='store_true')
     parser.add_argument("--noavx2", help="Do not use AVX2 instructions, a slower compatibility mode for older devices. Does not work with --clblast.", action='store_true')
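(The metavar values added to --blasthreads and --forceversion are cosmetic: argparse uses metavar only when rendering --help output, so those flags behave exactly as before.)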