added the ability to ban any substring tokens
parent 27a0907cfa
commit 8424a35c62
3 changed files with 62 additions and 2 deletions
expose.h (2 changes)

@@ -1,6 +1,7 @@
 #pragma once
 
 const int stop_token_max = 10;
+const int ban_token_max = 10;
 // match kobold's sampler list and order
 enum samplers
 {
@@ -35,6 +36,7 @@ struct load_model_inputs
     const int debugmode = 0;
     const int forceversion = 0;
     const int gpulayers = 0;
+    const char * banned_tokens[ban_token_max];
 };
 struct generation_inputs
 {
gpttype_adapter.cpp (49 changes)

@@ -76,6 +76,8 @@ static size_t mem_per_token = 0;
 static std::vector<float> logits;
 static std::vector<int> smartcontext;
 static std::vector<std::string> stop_sequence;
+static std::vector<std::string> banned_tokens;
+static std::vector<int> banned_token_ids;
 static std::vector<llama_token_data> top_picks;
 static int remaining_tokens = 0;
 static int stopper_unused_tokens = 0;
@@ -344,6 +346,17 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         = gpt2_ctx_v1.hparams.n_ctx = gpt2_ctx_v2.hparams.n_ctx = gpt2_ctx_v3.hparams.n_ctx
         = mpt_ctx_v3.hparams.n_ctx = params.n_ctx;
 
+    //handle custom token bans
+    banned_tokens.clear();
+    for(int x=0;x<ban_token_max;++x)
+    {
+        std::string word = inputs.banned_tokens[x];
+        if(word!="")
+        {
+            banned_tokens.push_back(word);
+        }
+    }
+
     //this is used for the mem_per_token eval, openblas needs more RAM
     bool use_scratch = ggml_cpu_has_gpublas();
 
@@ -1064,6 +1077,25 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         printf("Bad format!");
     }
 
+    //prepare banned tokens
+    if(banned_token_ids.size()==0 && banned_tokens.size()>0)
+    {
+        printf("\n[First Run] Banning %d token sequences...",banned_tokens.size());
+        for(int v=0;v<n_vocab;++v)
+        {
+            std::string word = FileFormatTokenizeID(v,file_format);
+            for(int i=0;i<banned_tokens.size();++i)
+            {
+                if (word.find(banned_tokens[i]) != std::string::npos)
+                {
+                    banned_token_ids.push_back(v);
+                    break;
+                }
+            }
+        }
+        printf("\nBanned a total of %d tokens.\n",banned_token_ids.size());
+    }
+
     if(debugmode!=-1)
     {
         printf("\n");
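The hunk above is the core of the feature: on the first generation call, every vocabulary entry is detokenized and checked against each banned substring, and any matching token ID is cached in banned_token_ids. A minimal Python sketch of the same scan, assuming a detokenize(token_id) helper and a vocab_size count (both stand-ins for the C++ calls, not the real API):

    def collect_banned_ids(detokenize, vocab_size, banned_substrings):
        # Return the IDs of every token whose text contains any banned substring.
        banned_ids = []
        for token_id in range(vocab_size):
            piece = detokenize(token_id)  # the C++ side uses FileFormatTokenizeID(v, file_format)
            if any(sub in piece for sub in banned_substrings):
                banned_ids.append(token_id)
        return banned_ids

Because matching is by substring, a single ban entry can knock out many tokens, which is why the scan prints the total number of banned IDs afterwards.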
@@ -1221,6 +1253,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 
         unsigned int eosID = 0;
         float * logitsPtr;
+        int btsize = banned_token_ids.size();
         if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3)
         {
             if(file_format == FileFormat::GGJT_3)
@@ -1239,6 +1272,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                 // set the logit of the eos token (2) to zero to avoid sampling it
                 logitsPtr[eosID] = 0;
             }
+
+            if(btsize>0)
+            {
+                for(int t=0;t<btsize;++t)
+                {
+                    logitsPtr[banned_token_ids[t]]=0;
+                }
+            }
         }
         else
         {
@@ -1293,6 +1334,14 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
                 }
             }
 
+            if(btsize>0)
+            {
+                int topid = std::min_element(logits.begin(), logits.end()) - logits.begin();
+                for (int t = 0; t < btsize; ++t)
+                {
+                    logits[banned_token_ids[t]] = (logits[topid] < 0 ? logits[topid] : 0);
+                }
+            }
         }
 
         id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty,
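Both branches above suppress the cached IDs just before sampling: the llama-format path zeroes the banned logits (the same treatment the EOS token already gets in that branch), while the generic path pushes them down to the smallest current logit, or 0 if every logit is positive. A rough Python equivalent of the second branch (names are illustrative, not the actual API):

    def suppress_banned(logits, banned_ids):
        # Push banned logits to the floor so those tokens are effectively never picked.
        floor = min(min(logits), 0.0)  # mirrors: logits[topid] < 0 ? logits[topid] : 0
        for token_id in banned_ids:
            logits[token_id] = floor
        return logits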
koboldcpp.py (13 changes)

@@ -13,6 +13,7 @@ from concurrent.futures import ThreadPoolExecutor
 
 stop_token_max = 10
 sampler_order_max = 7
+ban_token_max = 10
 
 class load_model_inputs(ctypes.Structure):
     _fields_ = [("threads", ctypes.c_int),
@@ -34,7 +35,8 @@ class load_model_inputs(ctypes.Structure):
                 ("blasbatchsize", ctypes.c_int),
                 ("debugmode", ctypes.c_int),
                 ("forceversion", ctypes.c_int),
-                ("gpulayers", ctypes.c_int)]
+                ("gpulayers", ctypes.c_int),
+                ("banned_tokens", ctypes.c_char_p * ban_token_max)]
 
 class generation_inputs(ctypes.Structure):
     _fields_ = [("seed", ctypes.c_int),
@@ -195,6 +197,13 @@ def load_model(model_filename):
         inputs.cublas_info = 2
     inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
     inputs.debugmode = args.debugmode
+    banned_tokens = args.bantokens
+    print(banned_tokens)
+    for n in range(ban_token_max):
+        if not banned_tokens or n >= len(banned_tokens):
+            inputs.banned_tokens[n] = "".encode("UTF-8")
+        else:
+            inputs.banned_tokens[n] = banned_tokens[n].encode("UTF-8")
     ret = handle.load_model(inputs)
     return ret
 
@@ -1297,7 +1306,7 @@ if __name__ == '__main__':
     parser.add_argument("--stream", help="Uses streaming when generating tokens. Only for the Kobold Lite UI.", action='store_true')
     parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
     parser.add_argument("--unbantokens", help="Normally, KoboldAI prevents the EOS token from being generated. This flag unbans it.", action='store_true')
-    parser.add_argument("--bantokens", help="You can manually specify a list of token IDs that the AI cannot use.", metavar=('[elements]'), nargs='+')
+    parser.add_argument("--bantokens", help="You can manually specify a list of token SUBSTRINGS that the AI cannot use. This bans ALL instances of that substring.", metavar=('[token_substrings]'), nargs='+')
     parser.add_argument("--usemirostat", help="Experimental! Replaces your samplers with mirostat. Takes 3 params = [type(0/1/2), tau(5.0), eta(0.1)].",metavar=('[type]', '[tau]', '[eta]'), type=float, nargs=3)
     parser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).",metavar=('[version]'), type=int, default=0)
     parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
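With the updated flag, --bantokens takes one or more plain substrings (nargs='+'), so args.bantokens is a list of strings, or None when the flag is absent, which is what the "if not banned_tokens" check in load_model handles. A small self-contained sketch of what argparse produces (the substrings below are examples only):

    import argparse

    # Minimal reproduction of the updated flag definition.
    parser = argparse.ArgumentParser()
    parser.add_argument("--bantokens", help="token substrings the AI cannot use",
                        metavar="[token_substrings]", nargs="+")

    args = parser.parse_args(["--bantokens", "###", "User:"])
    print(args.bantokens)  # ['###', 'User:'] -- omit the flag and this is None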