added token unbanning

parent 1b9b9068b1
commit 3962eb39c7

3 changed files with 32 additions and 16 deletions
expose.h
@@ -12,6 +12,7 @@ struct load_model_inputs
     const char * lora_filename;
     const bool use_mmap;
     const bool use_smartcontext;
+    const bool unban_tokens;
     const int clblast_info = 0;
     const int blasbatchsize = 512;
 };
gpttype_adapter.cpp
@@ -41,6 +41,7 @@ static int n_past = 0;
 static int n_threads = 4;
 static int n_batch = 8;
 static bool useSmartContext = false;
+static bool unbanTokens = false;
 static int blasbatchsize = 512;
 static std::string modelname;
 static std::vector<gpt_vocab::id> last_n_tokens;
@@ -65,6 +66,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     n_batch = params.n_batch = inputs.batch_size;
     modelname = params.model = inputs.model_filename;
     useSmartContext = inputs.use_smartcontext;
+    unbanTokens = inputs.unban_tokens;
     blasbatchsize = inputs.blasbatchsize;
     params.memory_f16 = inputs.f16_kv;
     params.n_ctx = inputs.max_context_length;
@@ -366,7 +368,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
         params.n_batch = bbs; //received reports of 1024 and above crashing on some models
         //params.n_threads = 1; //do not limit here anymore.
+        if(!ggml_cpu_has_cublas())
+        {
+            params.n_threads = 1; //do not limit here anymore.
+        }
     }

     current_context_tokens.resize(n_past);
@@ -512,28 +517,35 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT)
     {
         auto logits = llama_get_logits(llama_ctx_v1);
-        // set the logit of the eos token (2) to zero to avoid sampling it
-        logits[llama_token_eos()] = 0;
-        //set logits of opening square bracket to zero.
-        logits[518] = 0;
-        logits[29961] = 0;
+
+        if (!unbanTokens)
+        {
+            // set the logit of the eos token (2) to zero to avoid sampling it
+            logits[llama_token_eos()] = 0;
+            //set logits of opening square bracket to zero.
+            logits[518] = 0;
+            logits[29961] = 0;
+        }

         id = llama_sample_top_p_top_k(llama_ctx_v1, last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, repeat_penalty);

     }
     else
     {
-        // set the logit of the eos token (2) to zero to avoid sampling it
-        if((file_format == FileFormat::GPT2_1 ||
-            file_format == FileFormat::GPT2_2 ||
-            file_format == FileFormat::GPTJ_1 ||
-            file_format == FileFormat::GPTJ_2 ||
-            file_format == FileFormat::GPTJ_3)
-            && logits.size()>50256)
-        {
-            logits[50256] = (logits[50256] < 0 ? logits[50256] : 0);
-        }
-        //gpt2 uses negative logits, so we cant zero it
+        if (!unbanTokens)
+        {
+            // set the logit of the eos token (2) to zero to avoid sampling it
+            if ((file_format == FileFormat::GPT2_1 ||
+                 file_format == FileFormat::GPT2_2 ||
+                 file_format == FileFormat::GPTJ_1 ||
+                 file_format == FileFormat::GPTJ_2 ||
+                 file_format == FileFormat::GPTJ_3) &&
+                logits.size() > 50256)
+            {
+                logits[50256] = (logits[50256] < 0 ? logits[50256] : 0);
+            }
+            //gpt2 uses negative logits, so we cant zero it
+        }

         id = gptj_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
     }
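In short: the llama path zeroes the EOS logit and the two "[" token logits so they are not sampled, the GPT-2/GPT-J path clamps token 50256 to at most zero (per the comment above, gpt2 logits can be legitimately negative, so zeroing cannot be used there), and the new unbanTokens flag skips both blocks entirely. A stand-alone Python sketch of that logic, for illustration only (the function name apply_token_bans is hypothetical; the token ids are the ones hard-coded above):

    # logits: plain list of floats indexed by token id
    EOS_LLAMA = 2                      # llama_token_eos()
    OPEN_BRACKET_IDS = (518, 29961)    # "[" token variants banned above
    EOS_GPT2 = 50256

    def apply_token_bans(logits, is_llama, unban_tokens):
        if unban_tokens:
            return logits              # new behavior: leave every logit untouched
        if is_llama:
            logits[EOS_LLAMA] = 0      # llama logits can safely be zeroed
            for tid in OPEN_BRACKET_IDS:
                logits[tid] = 0
        elif len(logits) > EOS_GPT2:
            # gpt2 uses negative logits, so we can't zero it; clamp instead
            logits[EOS_GPT2] = min(logits[EOS_GPT2], 0)
        return logits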
koboldcpp.py
@@ -19,6 +19,7 @@ class load_model_inputs(ctypes.Structure):
                 ("lora_filename", ctypes.c_char_p),
                 ("use_mmap", ctypes.c_bool),
                 ("use_smartcontext", ctypes.c_bool),
+                ("unban_tokens", ctypes.c_bool),
                 ("clblast_info", ctypes.c_int),
                 ("blasbatchsize", ctypes.c_int)]
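One caveat when touching these bindings: ctypes lays a Structure out positionally, so ("unban_tokens", ctypes.c_bool) must occupy the same slot here as const bool unban_tokens does in expose.h, or every later field (clblast_info, blasbatchsize) would be read from the wrong offset on the C++ side. A minimal sketch of the idea, using a trimmed-down stand-in for the real struct:

    import ctypes

    # Field order must mirror the C struct declaration order exactly.
    class load_model_inputs(ctypes.Structure):
        _fields_ = [("use_smartcontext", ctypes.c_bool),
                    ("unban_tokens", ctypes.c_bool),   # new field, same slot as in expose.h
                    ("clblast_info", ctypes.c_int),
                    ("blasbatchsize", ctypes.c_int)]

    inputs = load_model_inputs()
    inputs.unban_tokens = True   # set before handing the struct to the C library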
@@ -96,9 +97,10 @@ def load_model(model_filename):
     inputs.batch_size = 8
     inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
     inputs.threads = args.threads
     inputs.f16_kv = True
     inputs.use_mmap = (not args.nommap)
     inputs.use_smartcontext = args.smartcontext
+    inputs.unban_tokens = args.unbantokens
     inputs.blasbatchsize = args.blasbatchsize
     clblastids = 0
     if args.useclblast:
@@ -497,6 +499,7 @@ if __name__ == '__main__':
     parser.add_argument("--blasbatchsize", help="Sets the batch size used in BLAS processing (default 512)", type=int,choices=[32,64,128,256,512,1024], default=512)
     parser.add_argument("--stream", help="Uses pseudo streaming", action='store_true')
     parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
+    parser.add_argument("--unbantokens", help="Normally, KoboldAI prevents certain tokens such as EOS and Square Brackets. This flag unbans them.", action='store_true')
     parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
     parser.add_argument("--noavx2", help="Do not use AVX2 instructions, a slower compatibility mode for older devices. Does not work with --clblast.", action='store_true')
     compatgroup = parser.add_mutually_exclusive_group()
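Because the argument is declared with action='store_true', args.unbantokens defaults to False, so existing behavior is unchanged unless the user opts in. A minimal sketch of that parsing behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--unbantokens", action='store_true')

    print(parser.parse_args([]).unbantokens)                 # False: bans stay active
    print(parser.parse_args(["--unbantokens"]).unbantokens)  # True: bans are skipped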