handle memory separately for kcpp

Concedo 2023-11-07 17:15:14 +08:00
parent f277ed0e8c
commit fb3bcac368
4 changed files with 105 additions and 22 deletions


@@ -54,6 +54,7 @@ struct generation_inputs
{
const int seed;
const char * prompt;
const char * memory;
const int max_context_length;
const int max_length;
const float temperature;
@@ -79,7 +80,7 @@ struct generation_inputs
struct generation_outputs
{
int status = -1;
char text[24576]; //24kb should be enough for any response
char text[32768]; //32kb should be enough for any response
};
extern std::string executable_path;


@@ -1388,6 +1388,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
stop_sequence.push_back(stopper);
}
}
std::string addedmemory = inputs.memory;
params.prompt = inputs.prompt;
params.seed = inputs.seed;
params.n_predict = inputs.max_length;
@@ -1442,7 +1443,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
// tokenize the prompt
std::vector<int> embd_inp;
std::vector<int> embd_inp_mem; //for storing added memory
TokenizeString(params.prompt, embd_inp, file_format);
if(addedmemory!="")
{
TokenizeString(addedmemory, embd_inp_mem, file_format);
}
//truncate the prompt from the front if it's too long
int32_t nctx = params.n_ctx;
@@ -1461,6 +1467,46 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
}
}
//added special memory, overwrite if needed
if(addedmemory!="")
{
//remove bos token from prompt, it'll be taken from memory
std::vector<int> bos;
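//tokenizing an empty string yields just the bos token, if the format uses one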
TokenizeString("", bos, file_format);
if (bos.size()>0 && !embd_inp.empty() && bos[0]==embd_inp[0]) {
embd_inp.erase(embd_inp.begin());
}
//shorten memory if needed
if (embd_inp_mem.size() + params.n_predict + 4 > nctx)
{
int offset = embd_inp_mem.size() - nctx + params.n_predict + 4;
embd_inp_mem = std::vector<int>(embd_inp_mem.begin() + offset, embd_inp_mem.end());
//re-insert the bos token at the front if one exists
if(bos.size()>0 && embd_inp_mem.size()>0)
{
embd_inp_mem[0] = bos[0];
}
}
//shorten main prompt by trimming the front if needed
int addmemtokens = embd_inp_mem.size();
int totalsize = (addmemtokens + embd_inp.size() + params.n_predict);
if(totalsize > nctx)
{
int excess = totalsize - nctx;
if (embd_inp.size() >= excess) {
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + excess);
} else {
embd_inp.clear();
}
}
//stick memory to front of prompt
embd_inp.insert(embd_inp.begin(), embd_inp_mem.begin(), embd_inp_mem.end());
}
//determine how much npast we have to rewind from the current state
std::vector<gpt_vocab::id> embd;
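
In short: the memory string is tokenized separately, trimmed from its front until it fits alongside the generation budget, and then pinned ahead of the (possibly front-truncated) prompt. A minimal Python sketch of the same budgeting, using plain token-id lists in place of the tokenizer; the function name and arguments are illustrative, not from the codebase:

def budget_context(mem, prompt, bos, n_ctx, n_predict):
    # drop the prompt's bos token; it will come from the memory block
    if bos and prompt and prompt[0] == bos[0]:
        prompt = prompt[1:]
    # shorten memory from the front so memory + n_predict + 4 fits in n_ctx
    if len(mem) + n_predict + 4 > n_ctx:
        offset = len(mem) - n_ctx + n_predict + 4
        mem = mem[offset:]
        if bos and mem:
            mem[0] = bos[0]  # restore bos at the new front
    # then trim the prompt's front so the combined total fits
    excess = len(mem) + len(prompt) + n_predict - n_ctx
    if excess > 0:
        prompt = prompt[excess:] if len(prompt) >= excess else []
    return mem + prompt  # memory always ends up at the front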


@@ -6,7 +6,7 @@ It requires no dependencies, installation or setup.
Just copy this single static HTML file anywhere and open it in a browser, or from a webserver.
Please go to https://github.com/LostRuins/lite.koboldai.net for updates on Kobold Lite.
Kobold Lite is under the AGPL v3.0 License unless otherwise exempted. Please do not remove this line.
Current version: 92
Current version: 93
-Concedo
-->
@@ -4009,6 +4009,10 @@ Current version: 92
{
return (custom_kobold_endpoint!="" && koboldcpp_version && koboldcpp_version!="" && compare_version_str(koboldcpp_version, "1.43") > 0);
}
function is_using_kcpp_with_added_memory()
{
return (custom_kobold_endpoint!="" && koboldcpp_version && koboldcpp_version!="" && compare_version_str(koboldcpp_version, "1.48.2") > 0);
}
//0 is none, 1 is pseudostreaming, 2 is true poll-streaming, 3 is sse-streaming
function determine_streaming_type()
@@ -7354,8 +7358,8 @@ Current version: 92
let max_allowed_characters = Math.floor(localsettings.max_context_length * 3.0)-100;
let truncated_context = concat_gametext(true, "");
let max_mem_anote_len = Math.floor(max_allowed_characters*0.9);
let truncated_memory = current_memory.substring(current_memory.length - max_mem_anote_len);
let max_mem_len = Math.floor(max_allowed_characters*0.8);
let truncated_memory = current_memory.substring(current_memory.length - max_mem_len);
if (truncated_memory != null && truncated_memory != "") {
truncated_memory += "\n";
}
@@ -7933,6 +7937,10 @@ Current version: 92
//if there is no memory, then we can be a lot more lenient with the character counts since the backend will truncate excess anyway
chars_per_token = 4.8;
}
if(is_using_kcpp_with_added_memory()) //easily handle overflow
{
chars_per_token = 6;
}
let max_allowed_characters = Math.max(1, Math.floor((maxctxlen-maxgenamt) * chars_per_token) - 12);
//for adventure mode, inject hidden context, even more if there's nothing in memory
@@ -8059,9 +8067,10 @@ Current version: 92
}
//we clip the memory if it's too long, taking the last x chars (not the first)
//memory or anote is allowed to be up to 0.9 times of ctx allowance
let max_mem_anote_len = Math.floor(max_allowed_characters*0.9);
let truncated_memory = substring_to_boundary(current_memory, max_mem_anote_len);
//memory is allowed to be up to 0.8 times of ctx allowance, anote up to 0.6 times
let max_mem_len = Math.floor(max_allowed_characters*0.8);
let max_anote_len = Math.floor(max_allowed_characters*0.6);
let truncated_memory = substring_to_boundary(current_memory, max_mem_len);
if (truncated_memory != null && truncated_memory != "") {
if(newlineaftermemory)
{
@@ -8129,23 +8138,29 @@ Current version: 92
//we clip the author's note if it's too long
let truncated_anote = current_anotetemplate.replace("<|>", current_anote);
truncated_anote = substring_to_boundary(truncated_anote, max_mem_anote_len);
truncated_anote = substring_to_boundary(truncated_anote, max_anote_len);
if (current_anote.length == 0) {
//if there's no author's note at all, don't include the template
truncated_anote = "";
}
//append memory to the start of the context, clipping excess space if needed
//only do this processing if memory or anote is not blank
if (truncated_memory.length > 0 || current_anote.length > 0) {
//now we resize the context such that the memory and authors note can fit inside
truncated_context = substring_to_boundary(truncated_context, max_allowed_characters);
//append memory to the start of the context, clipping excess space if needed
//only do this processing if memory or anote is not blank
if (truncated_memory.length > 0 || current_anote.length > 0)
{
if(!is_using_kcpp_with_added_memory())
{
let augmented_len = truncated_memory.length + truncated_context.length + truncated_anote.length;
let excess_len = augmented_len - max_allowed_characters; //if > 0, then we exceeded context window
excess_len = excess_len < 0 ? 0 : excess_len;
let newlimit = (max_allowed_characters-excess_len) < 32 ? 32 : (max_allowed_characters-excess_len);
truncated_context = substring_to_boundary(truncated_context, newlimit); //must always have at least 32 chars from main context
}
//insert the author's note 80 tokens before the ending (320 characters).
let anote_dist = anote_strength;
let anote_insert_idx = truncated_context.length - anote_dist;
@@ -8164,11 +8179,24 @@ Current version: 92
}
anote_insert_idx = clamp(anote_insert_idx, 0, truncated_context.length);
truncated_context = truncated_context.slice(0, anote_insert_idx) + truncated_anote + truncated_context.slice(anote_insert_idx);
if(!is_using_kcpp_with_added_memory())
{
truncated_context = truncated_memory + truncated_context;
}
}
truncated_memory = replace_placeholders(truncated_memory);
truncated_context = replace_placeholders(truncated_context);
if(is_using_kcpp_with_added_memory())
{
last_token_budget = (truncated_memory.length + truncated_context.length) + "/" + max_allowed_characters;
}
else
{
last_token_budget = truncated_context.length + "/" + max_allowed_characters;
}
let submit_payload = {
"prompt": truncated_context,
@@ -8190,6 +8218,11 @@ Current version: 92
"models": selected_models.map((m) => { return m.name }),
};
if(is_using_kcpp_with_added_memory())
{
submit_payload.params.memory = truncated_memory;
}
if(localsettings.sampler_seed>=1)
{
submit_payload.params.sampler_seed = localsettings.sampler_seed;
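
To summarize the client-side math: when the backend manages memory, Lite assumes 6 characters per token (since overflow is now truncated server-side), and memory may take up to 0.8x of the character allowance while the author's note may take up to 0.6x. A small Python model of those numbers; the function is illustrative, and the baseline chars_per_token is set elsewhere in Lite (4.8 and 6 are the overrides visible in this diff):

import math

def lite_budgets(max_ctx, max_gen, chars_per_token):
    # mirrors: max(1, floor((maxctxlen - maxgenamt) * chars_per_token) - 12)
    max_chars = max(1, math.floor((max_ctx - max_gen) * chars_per_token) - 12)
    max_mem = math.floor(max_chars * 0.8)    # memory share of the allowance
    max_anote = math.floor(max_chars * 0.6)  # author's note share
    return max_chars, max_mem, max_anote

# e.g. a 2048-token context reserving 128 tokens for generation:
print(lite_budgets(2048, 128, 6))  # -> (11508, 9206, 6904)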


@@ -49,6 +49,7 @@ class load_model_inputs(ctypes.Structure):
class generation_inputs(ctypes.Structure):
_fields_ = [("seed", ctypes.c_int),
("prompt", ctypes.c_char_p),
("memory", ctypes.c_char_p),
("max_context_length", ctypes.c_int),
("max_length", ctypes.c_int),
("temperature", ctypes.c_float),
@@ -73,7 +74,7 @@ class generation_inputs(ctypes.Structure):
class generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
("text", ctypes.c_char * 24576)]
("text", ctypes.c_char * 32768)]
handle = None
@@ -297,11 +298,12 @@ def load_model(model_filename):
ret = handle.load_model(inputs)
return ret
def generate(prompt,max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey=''):
def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey=''):
global maxctx, args, currentusergenkey, totalgens
inputs = generation_inputs()
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
inputs.prompt = prompt.encode("UTF-8")
inputs.memory = memory.encode("UTF-8")
if max_length >= max_context_length:
max_length = max_context_length-1
inputs.max_context_length = max_context_length # this will resize the context buffer if changed
@@ -379,7 +381,7 @@ maxhordelen = 256
modelbusy = threading.Lock()
requestsinqueue = 0
defaultport = 5001
KcppVersion = "1.48.1"
KcppVersion = "1.49"
showdebug = True
showsamplerwarning = True
showmaxctxwarning = True
@@ -474,6 +476,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
return generate(
prompt=genparams.get('prompt', ""),
memory=genparams.get('memory', ""),
max_context_length=genparams.get('max_context_length', maxctx),
max_length=genparams.get('max_length', 80),
temperature=genparams.get('temperature', 0.7),
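
End to end, a client can now send memory as its own field, and the server forwards genparams.get('memory', '') into the native generate call. A minimal sketch against a local koboldcpp (>= 1.49) instance on the default port; the endpoint and response shape follow the standard Kobold API, and the prompt/memory strings are just examples:

import json
import urllib.request

payload = {
    "prompt": "The dragon turned to face the knight.",
    "memory": "[You are narrating a high-fantasy adventure.]\n",
    "max_context_length": 2048,
    "max_length": 80,
}
req = urllib.request.Request(
    "http://localhost:5001/api/v1/generate",  # defaultport from this file
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    # standard Kobold API response shape: {"results": [{"text": "..."}]}
    print(json.load(resp)["results"][0]["text"])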