handle memory separately for kcpp

Concedo 2023-11-07 17:15:14 +08:00
parent f277ed0e8c
commit fb3bcac368
4 changed files with 105 additions and 22 deletions

View file

@@ -53,7 +53,8 @@ struct load_model_inputs
 struct generation_inputs
 {
     const int seed;
-    const char *prompt;
+    const char * prompt;
+    const char * memory;
     const int max_context_length;
     const int max_length;
     const float temperature;
@@ -79,7 +80,7 @@ struct generation_inputs
 struct generation_outputs
 {
     int status = -1;
-    char text[24576]; //24kb should be enough for any response
+    char text[32768]; //32kb should be enough for any response
 };
 extern std::string executable_path;

View file

@@ -1388,6 +1388,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             stop_sequence.push_back(stopper);
         }
     }
+    std::string addedmemory = inputs.memory;
     params.prompt = inputs.prompt;
     params.seed = inputs.seed;
     params.n_predict = inputs.max_length;
@@ -1442,7 +1443,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     // tokenize the prompt
     std::vector<int> embd_inp;
+    std::vector<int> embd_inp_mem; //for storing added memory
     TokenizeString(params.prompt, embd_inp, file_format);
+    if(addedmemory!="")
+    {
+        TokenizeString(addedmemory, embd_inp_mem, file_format);
+    }

     //truncate to front of the prompt if its too long
     int32_t nctx = params.n_ctx;
@@ -1461,6 +1467,46 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         }
     }

+    //added special memory, overwrite if needed
+    if(addedmemory!="")
+    {
+        //remove bos token from prompt, it'll be taken from memory
+        std::vector<int> bos;
+        TokenizeString("", bos, file_format);
+        if (bos.size()>0 && !embd_inp.empty() && bos[0]==embd_inp[0]) {
+            embd_inp.erase(embd_inp.begin());
+        }
+
+        //shorten memory if needed
+        if (embd_inp_mem.size() + params.n_predict + 4 > nctx)
+        {
+            int offset = embd_inp_mem.size() - nctx + params.n_predict + 4;
+            embd_inp_mem = std::vector<int>(embd_inp_mem.begin() + offset, embd_inp_mem.end());
+            //replace bos into front if exists
+            if(bos.size()>0 && embd_inp_mem.size()>0)
+            {
+                embd_inp_mem[0] = bos[0];
+            }
+        }
+
+        //shorten main prompt by trimming the front if needed
+        int addmemtokens = embd_inp_mem.size();
+        int totalsize = (addmemtokens + embd_inp.size() + params.n_predict);
+        if(totalsize > nctx)
+        {
+            int excess = totalsize - nctx;
+            if (embd_inp.size() >= excess) {
+                embd_inp.erase(embd_inp.begin(), embd_inp.begin() + excess);
+            } else {
+                embd_inp.clear();
+            }
+        }
+
+        //stick memory to front of prompt
+        embd_inp.insert(embd_inp.begin(), embd_inp_mem.begin(), embd_inp_mem.end());
+    }
+
     //determine how much npast we have to rewind from the current state
     std::vector<gpt_vocab::id> embd;
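For reference, the block added above boils down to a simple token-budget rule: trim the memory from the front so it still fits alongside the generation reserve, then trim the prompt from the front so memory + prompt + n_predict fit inside n_ctx, and finally pin the memory tokens ahead of the prompt. A minimal Python sketch of that rule, not the shipped C++; the function name apply_memory and the bos argument are illustrative stand-ins for the real tokenizer plumbing:

def apply_memory(prompt_tokens, memory_tokens, n_ctx, n_predict, bos=None):
    # drop the prompt's BOS token; it will be supplied by the memory instead
    if bos is not None and prompt_tokens[:1] == [bos]:
        prompt_tokens = prompt_tokens[1:]
    # shorten memory from the front if it cannot fit next to the generation reserve
    overflow = len(memory_tokens) + n_predict + 4 - n_ctx
    if overflow > 0:
        memory_tokens = memory_tokens[overflow:]
        if bos is not None and memory_tokens:
            memory_tokens[0] = bos  # put BOS back at the very front
    # shorten the prompt from the front so everything fits in the context window
    excess = len(memory_tokens) + len(prompt_tokens) + n_predict - n_ctx
    if excess > 0:
        prompt_tokens = prompt_tokens[excess:]
    # memory always ends up in front of the prompt
    return memory_tokens + prompt_tokens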

View file

@@ -6,7 +6,7 @@ It requires no dependencies, installation or setup.
 Just copy this single static HTML file anywhere and open it in a browser, or from a webserver.
 Please go to https://github.com/LostRuins/lite.koboldai.net for updates on Kobold Lite.
 Kobold Lite is under the AGPL v3.0 License unless otherwise exempted. Please do not remove this line.
-Current version: 92
+Current version: 93
 -Concedo
 -->
@@ -4009,6 +4009,10 @@ Current version: 92
     {
        return (custom_kobold_endpoint!="" && koboldcpp_version && koboldcpp_version!="" && compare_version_str(koboldcpp_version, "1.43") > 0);
     }
+    function is_using_kcpp_with_added_memory()
+    {
+        return (custom_kobold_endpoint!="" && koboldcpp_version && koboldcpp_version!="" && compare_version_str(koboldcpp_version, "1.48.2") > 0);
+    }

     //0 is none, 1 is pseudostreaming, 2 is true poll-streaming, 3 is sse-streaming
     function determine_streaming_type()
@@ -7354,8 +7358,8 @@ Current version: 92
         let max_allowed_characters = Math.floor(localsettings.max_context_length * 3.0)-100;
         let truncated_context = concat_gametext(true, "");
-        let max_mem_anote_len = Math.floor(max_allowed_characters*0.9);
-        let truncated_memory = current_memory.substring(current_memory.length - max_mem_anote_len);
+        let max_mem_len = Math.floor(max_allowed_characters*0.8);
+        let truncated_memory = current_memory.substring(current_memory.length - max_mem_len);
         if (truncated_memory != null && truncated_memory != "") {
             truncated_memory += "\n";
         }
@@ -7933,6 +7937,10 @@ Current version: 92
             //if there is no memory, then we can be a lot of lenient with the character counts since the backend will truncate excess anyway
             chars_per_token = 4.8;
         }
+        if(is_using_kcpp_with_added_memory()) //easily handle overflow
+        {
+            chars_per_token = 6;
+        }

         let max_allowed_characters = Math.max(1, Math.floor((maxctxlen-maxgenamt) * chars_per_token) - 12);
         //for adventure mode, inject hidden context, even more if there's nothing in memory
@@ -8059,9 +8067,10 @@ Current version: 92
         }

         //we clip the memory if its too long, taking the last x chars (not the first)
-        //memory or anote is allowed to be up to 0.9 times of ctx allowance
-        let max_mem_anote_len = Math.floor(max_allowed_characters*0.9);
-        let truncated_memory = substring_to_boundary(current_memory, max_mem_anote_len);
+        //memory is allowed to be up to 0.8 times of ctx allowance, anote up to 0.6 times
+        let max_mem_len = Math.floor(max_allowed_characters*0.8);
+        let max_anote_len = Math.floor(max_allowed_characters*0.6);
+        let truncated_memory = substring_to_boundary(current_memory, max_mem_len);
         if (truncated_memory != null && truncated_memory != "") {
             if(newlineaftermemory)
             {
@@ -8129,23 +8138,29 @@ Current version: 92
         //we clip the authors note if its too long
         let truncated_anote = current_anotetemplate.replace("<|>", current_anote);
-        truncated_anote = substring_to_boundary(truncated_anote, max_mem_anote_len);
+        truncated_anote = substring_to_boundary(truncated_anote, max_anote_len);
         if (current_anote.length == 0) {
             //if there's no authors note at all, don't include the template
             truncated_anote = "";
         }

-        //append memory to the start of the context, clipping excess space if needed
-        //only do this processing if memory or anote is not blank
-        if (truncated_memory.length > 0 || current_anote.length > 0) {
         //now we resize the context such that the memory and authors note can fit inside
         truncated_context = substring_to_boundary(truncated_context, max_allowed_characters);

+        //append memory to the start of the context, clipping excess space if needed
+        //only do this processing if memory or anote is not blank
+        if (truncated_memory.length > 0 || current_anote.length > 0)
+        {
+            if(!is_using_kcpp_with_added_memory())
+            {
             let augmented_len = truncated_memory.length + truncated_context.length + truncated_anote.length;
             let excess_len = augmented_len - max_allowed_characters; //if > 0, then we exceeded context window
             excess_len = excess_len < 0 ? 0 : excess_len;
             let newlimit = (max_allowed_characters-excess_len) < 32 ? 32 : (max_allowed_characters-excess_len);
             truncated_context = substring_to_boundary(truncated_context, newlimit); //must always have at least 32 chars from main context
+            }

             //insert authors note 80 tokens before the ending (320 characters).
             let anote_dist = anote_strength;
             let anote_insert_idx = truncated_context.length - anote_dist;
@@ -8164,11 +8179,24 @@ Current version: 92
             }
             anote_insert_idx = clamp(anote_insert_idx, 0, truncated_context.length);
             truncated_context = truncated_context.slice(0, anote_insert_idx) + truncated_anote + truncated_context.slice(anote_insert_idx);
+
+            if(!is_using_kcpp_with_added_memory())
+            {
             truncated_context = truncated_memory + truncated_context;
             }
+        }

+        truncated_memory = replace_placeholders(truncated_memory);
         truncated_context = replace_placeholders(truncated_context);

+        if(is_using_kcpp_with_added_memory())
+        {
+            last_token_budget = (truncated_memory.length + truncated_context.length) + "/" + max_allowed_characters;
+        }
+        else
+        {
             last_token_budget = truncated_context.length + "/" + max_allowed_characters;
+        }

         let submit_payload = {
             "prompt": truncated_context,
@@ -8190,6 +8218,11 @@ Current version: 92
             "models": selected_models.map((m) => { return m.name }),
         };

+        if(is_using_kcpp_with_added_memory())
+        {
+            submit_payload.params.memory = truncated_memory;
+        }
+
         if(localsettings.sampler_seed>=1)
         {
             submit_payload.params.sampler_seed = localsettings.sampler_seed;
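In effect, when the client detects a backend newer than 1.48.2, the memory text is no longer glued onto the front of the prompt string; it travels in its own field and the backend takes over the truncation. A rough sketch of the resulting payload shape, with placeholder values:

submit_payload = {
    "prompt": "...story text only; memory is no longer prepended here...",
    "params": {
        "memory": "...memory text, sent as its own field...",
        "max_context_length": 2048,
        "max_length": 80,
        # ...remaining sampler settings unchanged...
    },
}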

View file

@@ -49,6 +49,7 @@ class load_model_inputs(ctypes.Structure):
 class generation_inputs(ctypes.Structure):
     _fields_ = [("seed", ctypes.c_int),
                 ("prompt", ctypes.c_char_p),
+                ("memory", ctypes.c_char_p),
                 ("max_context_length", ctypes.c_int),
                 ("max_length", ctypes.c_int),
                 ("temperature", ctypes.c_float),
@@ -73,7 +74,7 @@ class generation_inputs(ctypes.Structure):
 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
-                ("text", ctypes.c_char * 24576)]
+                ("text", ctypes.c_char * 32768)]

 handle = None
@@ -297,11 +298,12 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret

-def generate(prompt,max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey=''):
+def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey=''):
     global maxctx, args, currentusergenkey, totalgens
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
     inputs.prompt = prompt.encode("UTF-8")
+    inputs.memory = memory.encode("UTF-8")
     if max_length >= max_context_length:
         max_length = max_context_length-1
     inputs.max_context_length = max_context_length   # this will resize the context buffer if changed
@@ -379,7 +381,7 @@ maxhordelen = 256
 modelbusy = threading.Lock()
 requestsinqueue = 0
 defaultport = 5001
-KcppVersion = "1.48.1"
+KcppVersion = "1.49"
 showdebug = True
 showsamplerwarning = True
 showmaxctxwarning = True
@@ -474,6 +476,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         return generate(
             prompt=genparams.get('prompt', ""),
+            memory=genparams.get('memory', ""),
             max_context_length=genparams.get('max_context_length', maxctx),
             max_length=genparams.get('max_length', 80),
             temperature=genparams.get('temperature', 0.7),
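Putting the server-side changes together, a caller can now pass memory alongside the prompt in the generation request. A minimal sketch, assuming the server runs on the default port 5001 and exposes the usual /api/v1/generate route with a KoboldAI-style response; the example text values are placeholders:

import requests

resp = requests.post(
    "http://localhost:5001/api/v1/generate",
    json={
        "prompt": "The dragon looked at me and said,",
        "memory": "[This is a high-fantasy adventure story.]\n",  # new field handled by this commit
        "max_context_length": 2048,
        "max_length": 80,
    },
    timeout=120,
)
print(resp.json()["results"][0]["text"])  # assumes the usual KoboldAI response schema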