Merge remote-tracking branch 'ycros/improve-sampler-api-access' into concedo_experimental
This commit is contained in:
commit
784628a2be
3 changed files with 84 additions and 8 deletions
14
expose.h
14
expose.h
|
@ -1,6 +1,18 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
const int stop_token_max = 10;
|
const int stop_token_max = 10;
|
||||||
|
// match kobold's sampler list and order
|
||||||
|
enum samplers
|
||||||
|
{
|
||||||
|
KCPP_SAMPLER_TOP_K,
|
||||||
|
KCPP_SAMPLER_TOP_A,
|
||||||
|
KCPP_SAMPLER_TOP_P,
|
||||||
|
KCPP_SAMPLER_TFS,
|
||||||
|
KCPP_SAMPLER_TYP,
|
||||||
|
KCPP_SAMPLER_TEMP,
|
||||||
|
KCPP_SAMPLER_REP_PEN,
|
||||||
|
KCPP_SAMPLER_MAX
|
||||||
|
};
|
||||||
struct load_model_inputs
|
struct load_model_inputs
|
||||||
{
|
{
|
||||||
const int threads;
|
const int threads;
|
||||||
|
@ -40,6 +52,8 @@ struct generation_inputs
|
||||||
const int mirostat = 0;
|
const int mirostat = 0;
|
||||||
const float mirostat_eta;
|
const float mirostat_eta;
|
||||||
const float mirostat_tau;
|
const float mirostat_tau;
|
||||||
|
const samplers sampler_order[KCPP_SAMPLER_MAX];
|
||||||
|
const int sampler_len;
|
||||||
const char * stop_sequence[stop_token_max];
|
const char * stop_sequence[stop_token_max];
|
||||||
const bool stream_sse;
|
const bool stream_sse;
|
||||||
};
|
};
|
||||||
|
|
|
@ -219,8 +219,16 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
|
||||||
candidates->size = last_idx;
|
candidates->size = last_idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void apply_penalties(int n_ctx, int rep_pen_range, float rep_pen, llama_token_data_array & candidates_p)
|
||||||
|
{
|
||||||
|
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
|
||||||
|
llama_sample_repetition_penalty(nullptr, &candidates_p,
|
||||||
|
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||||
|
last_n_repeat, rep_pen);
|
||||||
|
}
|
||||||
|
|
||||||
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
|
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float top_k, float top_a, float top_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
|
||||||
int mirostat, float mirostat_tau, float mirostat_eta)
|
int mirostat, float mirostat_tau, float mirostat_eta, uint sampler_len, const samplers sampler_order[KCPP_SAMPLER_MAX])
|
||||||
{
|
{
|
||||||
int id = 0;
|
int id = 0;
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
|
@ -231,11 +239,11 @@ int mirostat, float mirostat_tau, float mirostat_eta)
|
||||||
|
|
||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
|
|
||||||
// Apply penalties
|
// Run this except for when we are going to do the sampler reordering case below
|
||||||
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
|
if (temp <= 0 || mirostat > 0 || sampler_len == 0)
|
||||||
llama_sample_repetition_penalty(nullptr, &candidates_p,
|
{
|
||||||
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p);
|
||||||
last_n_repeat, rep_pen);
|
}
|
||||||
|
|
||||||
// llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p,
|
// llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p,
|
||||||
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
// last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
||||||
|
@ -261,6 +269,37 @@ int mirostat, float mirostat_tau, float mirostat_eta)
|
||||||
llama_sample_temperature(nullptr, &candidates_p, temp);
|
llama_sample_temperature(nullptr, &candidates_p, temp);
|
||||||
id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
|
id = sample_token_mirostat_v2(&candidates_p, rng, mirostat_tau, mirostat_eta, &mirostat_mu);
|
||||||
}
|
}
|
||||||
|
else if (sampler_len > 0)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < sampler_len; i++) {
|
||||||
|
switch (sampler_order[i]) {
|
||||||
|
case KCPP_SAMPLER_TOP_K:
|
||||||
|
llama_sample_top_k(nullptr, &candidates_p, top_k,1);
|
||||||
|
break;
|
||||||
|
case KCPP_SAMPLER_TOP_A:
|
||||||
|
sample_top_a(&candidates_p,top_a,1);
|
||||||
|
break;
|
||||||
|
case KCPP_SAMPLER_TOP_P:
|
||||||
|
llama_sample_top_p(nullptr, &candidates_p, top_p,1);
|
||||||
|
break;
|
||||||
|
case KCPP_SAMPLER_TFS:
|
||||||
|
llama_sample_tail_free(nullptr, &candidates_p, tfs,1);
|
||||||
|
break;
|
||||||
|
case KCPP_SAMPLER_TYP:
|
||||||
|
llama_sample_typical(nullptr, &candidates_p, typical_p,1);
|
||||||
|
break;
|
||||||
|
case KCPP_SAMPLER_TEMP:
|
||||||
|
llama_sample_temperature(nullptr, &candidates_p, temp);
|
||||||
|
break;
|
||||||
|
case KCPP_SAMPLER_REP_PEN:
|
||||||
|
apply_penalties(n_ctx, rep_pen_range, rep_pen, candidates_p);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
id = sample_token(&candidates_p, rng);
|
||||||
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
// Temperature sampling
|
// Temperature sampling
|
||||||
|
@ -1235,7 +1274,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
||||||
|
|
||||||
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
|
id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
|
||||||
top_k, top_a, top_p, typical_p, tfs_z, temp, rng,
|
top_k, top_a, top_p, typical_p, tfs_z, temp, rng,
|
||||||
params.mirostat,params.mirostat_tau,params.mirostat_eta);
|
params.mirostat, params.mirostat_tau, params.mirostat_eta,
|
||||||
|
inputs.sampler_len, inputs.sampler_order);
|
||||||
|
|
||||||
last_n_tokens.erase(last_n_tokens.begin());
|
last_n_tokens.erase(last_n_tokens.begin());
|
||||||
last_n_tokens.push_back(id);
|
last_n_tokens.push_back(id);
|
||||||
|
|
24
koboldcpp.py
24
koboldcpp.py
|
@ -9,6 +9,7 @@ import json, sys, http.server, time, asyncio, socket, threading
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
stop_token_max = 10
|
stop_token_max = 10
|
||||||
|
sampler_order_max = 7
|
||||||
|
|
||||||
class load_model_inputs(ctypes.Structure):
|
class load_model_inputs(ctypes.Structure):
|
||||||
_fields_ = [("threads", ctypes.c_int),
|
_fields_ = [("threads", ctypes.c_int),
|
||||||
|
@ -47,6 +48,8 @@ class generation_inputs(ctypes.Structure):
|
||||||
("mirostat", ctypes.c_int),
|
("mirostat", ctypes.c_int),
|
||||||
("mirostat_tau", ctypes.c_float),
|
("mirostat_tau", ctypes.c_float),
|
||||||
("mirostat_eta", ctypes.c_float),
|
("mirostat_eta", ctypes.c_float),
|
||||||
|
("sampler_order", ctypes.c_int * sampler_order_max),
|
||||||
|
("sampler_len", ctypes.c_int),
|
||||||
("stop_sequence", ctypes.c_char_p * stop_token_max),
|
("stop_sequence", ctypes.c_char_p * stop_token_max),
|
||||||
("stream_sse", ctypes.c_bool)]
|
("stream_sse", ctypes.c_bool)]
|
||||||
|
|
||||||
|
@ -186,7 +189,7 @@ def load_model(model_filename):
|
||||||
ret = handle.load_model(inputs)
|
ret = handle.load_model(inputs)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=120, top_a=0.0 ,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[],stream_sse=False):
|
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=None, seed=-1, stop_sequence=[], stream_sse=False):
|
||||||
inputs = generation_inputs()
|
inputs = generation_inputs()
|
||||||
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
|
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
|
||||||
inputs.prompt = prompt.encode("UTF-8")
|
inputs.prompt = prompt.encode("UTF-8")
|
||||||
|
@ -205,8 +208,19 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
|
||||||
inputs.mirostat = int(args.usemirostat[0])
|
inputs.mirostat = int(args.usemirostat[0])
|
||||||
inputs.mirostat_tau = float(args.usemirostat[1])
|
inputs.mirostat_tau = float(args.usemirostat[1])
|
||||||
inputs.mirostat_eta = float(args.usemirostat[2])
|
inputs.mirostat_eta = float(args.usemirostat[2])
|
||||||
|
elif mirostat in (1, 2):
|
||||||
|
inputs.mirostat = mirostat
|
||||||
|
inputs.mirostat_tau = mirostat_tau
|
||||||
|
inputs.mirostat_eta = mirostat_eta
|
||||||
else:
|
else:
|
||||||
inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
|
inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
|
||||||
|
if sampler_order and 0 < len(sampler_order) <= sampler_order_max:
|
||||||
|
try:
|
||||||
|
for i, sampler in enumerate(sampler_order):
|
||||||
|
inputs.sampler_order[i] = sampler
|
||||||
|
inputs.sampler_len = len(sampler_order)
|
||||||
|
except TypeError as e:
|
||||||
|
print("ERROR: sampler_order must be a list of integers: " + str(e))
|
||||||
inputs.seed = seed
|
inputs.seed = seed
|
||||||
for n in range(stop_token_max):
|
for n in range(stop_token_max):
|
||||||
if not stop_sequence or n >= len(stop_sequence):
|
if not stop_sequence or n >= len(stop_sequence):
|
||||||
|
@ -272,6 +286,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
tfs=genparams.get('tfs', 1.0),
|
tfs=genparams.get('tfs', 1.0),
|
||||||
rep_pen=genparams.get('rep_pen', 1.1),
|
rep_pen=genparams.get('rep_pen', 1.1),
|
||||||
rep_pen_range=genparams.get('rep_pen_range', 128),
|
rep_pen_range=genparams.get('rep_pen_range', 128),
|
||||||
|
mirostat=genparams.get('mirostat', 0),
|
||||||
|
mirostat_tau=genparams.get('mirostat_tau', 5.0),
|
||||||
|
mirostat_eta=genparams.get('mirostat_eta', 0.1),
|
||||||
|
sampler_order=genparams.get('sampler_order', None),
|
||||||
seed=genparams.get('sampler_seed', -1),
|
seed=genparams.get('sampler_seed', -1),
|
||||||
stop_sequence=genparams.get('stop_sequence', []),
|
stop_sequence=genparams.get('stop_sequence', []),
|
||||||
stream_sse=stream_flag)
|
stream_sse=stream_flag)
|
||||||
|
@ -288,6 +306,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||||
tfs=genparams.get('tfs', 1.0),
|
tfs=genparams.get('tfs', 1.0),
|
||||||
rep_pen=genparams.get('rep_pen', 1.1),
|
rep_pen=genparams.get('rep_pen', 1.1),
|
||||||
rep_pen_range=genparams.get('rep_pen_range', 128),
|
rep_pen_range=genparams.get('rep_pen_range', 128),
|
||||||
|
mirostat=genparams.get('mirostat', 0),
|
||||||
|
mirostat_tau=genparams.get('mirostat_tau', 5.0),
|
||||||
|
mirostat_eta=genparams.get('mirostat_eta', 0.1),
|
||||||
|
sampler_order=genparams.get('sampler_order', None),
|
||||||
seed=genparams.get('sampler_seed', -1),
|
seed=genparams.get('sampler_seed', -1),
|
||||||
stop_sequence=genparams.get('stop_sequence', []),
|
stop_sequence=genparams.get('stop_sequence', []),
|
||||||
stream_sse=stream_flag)
|
stream_sse=stream_flag)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue