Added SmartContext mode, a prompt context manipulation scheme that avoids frequent context recalculation.
parent ca297c190f
commit adb4df78d6

6 changed files with 254 additions and 51 deletions
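In brief, per the comments added in this commit: when the context window is close to full, the prompt is chopped roughly in half after the start portion, and the retained half is memorized as a marker (the smartcontext buffer). When a later prompt arrives, the longest common token run between that prompt and the marker is computed; if the match exceeds 32 tokens and is anchored at the start of the marker, every token between the start portion and the match is dropped, so the previously evaluated context can be fast-forwarded rather than recomputed. If no usable match is found, the marker is cleared and rebuilt on the next overflow. A standalone sketch of this reuse decision appears after the new ContextFastForward implementation below.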
expose.h (1 change)

@@ -9,6 +9,7 @@ struct load_model_inputs
     const char *model_filename;
     const int n_parts_overwrite = -1;
     const bool use_mmap;
+    const bool use_smartcontext;
     const int clblast_info = 0;
 };
 
 struct generation_inputs
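load_model_inputs crosses the C/Python boundary as a plain struct, so the new use_smartcontext field sits at the same position (between use_mmap and clblast_info) in the ctypes mirror in koboldcpp.py below; ctypes lays out fields in declaration order, and the two definitions must stay in sync for the flag to land in the right slot.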
@@ -35,6 +35,8 @@ static std::vector<gpt_vocab::id> current_context_tokens;
 static size_t mem_per_token = 0;
 static std::vector<float> logits;
 
+static std::vector<int> smartcontext;
+
 inline bool IsNanCheck(float f)
 {
     const unsigned int u = *(unsigned int*)&f;
@@ -194,27 +196,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output)
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     n_past = 0;
 
-    //fast forward the past based on identical tokens, stop once a divergence is noted
-    int embd_inp_len = embd_inp.size();
-    for (int i = 0; i < current_context_tokens.size(); ++i)
-    {
-        if (current_context_tokens[i] == embd_inp[i])
-        {
-            n_past += 1;
-            last_n_tokens.push_back(current_context_tokens[i]);
-        }
-        else
-        {
-            break;
-        }
-        if ((i + 2) >= embd_inp_len)
-        {
-            break;
-        }
-    }
-
-    last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + n_past);
-    embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
-
+    ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, true);
 
     //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
     // bool approved_format = (file_format!=FileFormat::GPT2_1 && file_format!=FileFormat::GPTJ_1 && file_format!=FileFormat::GPTJ_2);
koboldcpp.py (11 changes)

@@ -16,6 +16,7 @@ class load_model_inputs(ctypes.Structure):
                 ("model_filename", ctypes.c_char_p),
                 ("n_parts_overwrite", ctypes.c_int),
                 ("use_mmap", ctypes.c_bool),
+                ("use_smartcontext", ctypes.c_bool),
                 ("clblast_info", ctypes.c_int)]
 
 class generation_inputs(ctypes.Structure):
@@ -65,7 +66,7 @@ def init_library():
     handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
     handle.generate.restype = generation_outputs
 
-def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1,threads=6,use_mmap=False):
+def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1,threads=6,use_mmap=False,use_smartcontext=False):
     inputs = load_model_inputs()
     inputs.model_filename = model_filename.encode("UTF-8")
     inputs.batch_size = batch_size
@@ -74,6 +75,7 @@ def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1,threads=6,use_mmap=False):
     inputs.n_parts_overwrite = n_parts_overwrite
     inputs.f16_kv = True
     inputs.use_mmap = use_mmap
+    inputs.use_smartcontext = use_smartcontext
     clblastids = 0
     if args.useclblast:
         clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
@@ -383,8 +385,8 @@ def main(args):
 
     mdl_nparts = sum(1 for n in range(1, 9) if os.path.exists(f"{ggml_selected_file}.{n}")) + 1
     modelname = os.path.abspath(ggml_selected_file)
-    print(f"Loading model: {modelname} \n[Parts: {mdl_nparts}, Threads: {args.threads}]")
-    loadok = load_model(modelname,8,maxctx,mdl_nparts,args.threads,(not args.nommap))
+    print(f"Loading model: {modelname} \n[Parts: {mdl_nparts}, Threads: {args.threads}, SmartContext: {args.smartcontext}]")
+    loadok = load_model(modelname,8,maxctx,mdl_nparts,args.threads,(not args.nommap),args.smartcontext)
     print("Load Model OK: " + str(loadok))
 
     if not loadok:
@@ -413,7 +415,7 @@ def main(args):
     RunServerMultiThreaded(args.host, args.port, embedded_kailite)
 
 if __name__ == '__main__':
-    print("Welcome to KoboldCpp - Version 1.6") # just update version manually
+    print("Welcome to KoboldCpp - Version 1.7") # just update version manually
     parser = argparse.ArgumentParser(description='Kobold llama.cpp server')
     parser.add_argument("model_file", help="Model file to load", nargs="?")
     portgroup = parser.add_mutually_exclusive_group() #we want to be backwards compatible with the unnamed positional args
@@ -430,6 +432,7 @@ if __name__ == '__main__':
     parser.add_argument("--threads", help="Use a custom number of threads if specified. Otherwise, uses an amount based on CPU cores", type=int, default=default_threads)
     parser.add_argument("--psutil_set_threads", help="Experimental flag. If set, uses psutils to determine thread count based on physical cores.", action='store_true')
     parser.add_argument("--stream", help="Uses pseudo streaming", action='store_true')
+    parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
     parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
     parser.add_argument("--noavx2", help="Do not use AVX2 instructions, a slower compatibility mode for older devices. Does not work with --clblast.", action='store_true')
     compatgroup = parser.add_mutually_exclusive_group()
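On the Python side the feature is strictly opt-in: launching with the new --smartcontext flag (for example, python koboldcpp.py model.bin --smartcontext, where model.bin stands in for any supported model file) threads the setting through load_model into the native load_model_inputs struct.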
@@ -31,6 +31,7 @@ static std::string modelname;
 static llama_context *ctx;
 static std::vector<llama_token> last_n_tokens;
 static std::vector<llama_token> current_context_tokens;
+static std::vector<llama_token> smartcontext;
 
 bool llama_load_model(const load_model_inputs inputs, FileFormat in_file_format)
 {
@@ -115,9 +116,10 @@ generation_outputs llama_generate(const generation_inputs inputs, generation_outputs &output)
     }
 
     //truncate to front of the prompt if its too long
-    if (embd_inp.size() + params.n_predict > params.n_ctx)
+    int32_t nctx = params.n_ctx;
+    if (embd_inp.size() + params.n_predict > nctx)
     {
-        int offset = embd_inp.size() - params.n_ctx + params.n_predict;
+        int offset = embd_inp.size() - nctx + params.n_predict;
         embd_inp = std::vector<llama_token>(embd_inp.begin() + offset, embd_inp.end());
     }
 
@@ -131,28 +133,7 @@ generation_outputs llama_generate(const generation_inputs inputs, generation_outputs &output)
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     n_past = 0;
 
-    //fast forward the past based on identical tokens, stop once a divergence is noted
-    int embd_inp_len = embd_inp.size();
-    int ctxcs = current_context_tokens.size();
-    for (int i = 0; i < ctxcs; ++i)
-    {
-        if (current_context_tokens[i] == embd_inp[i])
-        {
-            n_past += 1;
-            last_n_tokens.push_back(current_context_tokens[i]);
-        }
-        else
-        {
-            break;
-        }
-        if ((i + 2) >= embd_inp_len)
-        {
-            break;
-        }
-    }
-
-    last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + n_past);
-    embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
-
+    ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, true);
 
     //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
     bool blasmode = (embd_inp.size() >= 32 && ggml_cpu_has_blas());
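Both generate paths above (gpttype and llama) now replace their duplicated token fast-forward loops with a single call to the shared ContextFastForward helper; its implementation, along with the sequence-matching utilities it depends on, follows.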
@@ -28,6 +28,10 @@ double timer_check()
 }
 
 void print_tok_vec(std::vector<int> &embd)
+{
+    print_tok_vec(embd,nullptr);
+}
+void print_tok_vec(std::vector<int> &embd, std::map<int32_t, std::string> * decoder)
 {
     std::cout << "[";
     bool first = true;
@@ -38,7 +42,14 @@ void print_tok_vec(std::vector<int> &embd)
             std::cout << ',';
         }
         first = false;
-        std::cout << i;
+        if(decoder)
+        {
+            std::cout << (*decoder)[i];
+        }
+        else
+        {
+            std::cout << i;
+        }
     }
     std::cout << "]\n";
 }
@@ -125,4 +136,222 @@ void print_tok_vec(std::vector<float> &embd)
     fin.close();
 
     return fileformat;
 }
+
+bool ArrStartWith(const std::vector<int> targetArray, const std::vector<int> searchSeq)
+{
+    int ss = searchSeq.size();
+    if(targetArray.size()<ss)
+    {
+        return false;
+    }
+    for(int i=0;i<ss;++i)
+    {
+        if(targetArray[i]!=searchSeq[i])
+        {
+            return false;
+        }
+    }
+    return true;
+}
+
+int ArrFindIndexOf(const std::vector<int> targetArray, const std::vector<int> searchSeq)
+{
+    int ss = searchSeq.size();
+    int tas = targetArray.size();
+    if(tas<ss)
+    {
+        return -1;
+    }
+    for(int i=0;i<tas;++i)
+    {
+        int srch = 0;
+        bool fail = false;
+        for(int srch=0;srch<ss;++srch)
+        {
+            if ((i + srch) >= tas || targetArray[i + srch] != searchSeq[srch])
+            {
+                fail = true;
+                break;
+            }
+        }
+        if(!fail)
+        {
+            return i;
+        }
+    }
+    return -1;
+}
+
+std::vector<int> LongestCommonSubseq(const std::vector<int> x, const std::vector<int> y)
+{
+    int m = x.size(), n = y.size();
+
+    //int LCSuff[m+1][n+1];
+    std::vector<std::vector<int>> LCSuff(m+1, std::vector<int>(n+1));
+
+    for (int j = 0; j <= n; j++)
+        LCSuff[0][j] = 0;
+    for (int i = 0; i <= m; i++)
+        LCSuff[i][0] = 0;
+
+    for (int i = 1; i <= m; i++)
+    {
+        for (int j = 1; j <= n; j++)
+        {
+            if (x[i - 1] == y[j - 1])
+                LCSuff[i][j] = LCSuff[i - 1][j - 1] + 1;
+            else
+                LCSuff[i][j] = 0;
+        }
+    }
+
+    std::vector<int> longest;
+    for (int i = 1; i <= m; i++)
+    {
+        for (int j = 1; j <= n; j++)
+        {
+            if (LCSuff[i][j] > longest.size())
+            {
+                auto off1 = ((i - LCSuff[i][j] + 1) - 1);
+                auto off2 = off1 + LCSuff[i][j];
+                longest.clear();
+                // std::vector<int>().swap(longest);
+                longest = std::vector<int>(x.begin() + off1, x.begin() + off2);
+                // x.substr((i - LCSuff[i][j] + 1) - 1, LCSuff[i][j]);
+            }
+        }
+    }
+    return longest;
+}
+
+void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
+    int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext, bool useSmartContext)
+{
+    const int SCTokThreshold = 32; //how many tokens of similarity triggers smartcontext
+    const int SCCtxLenThreshold = nctx * 0.8; //how much context length must be reach to trigger smartcontext
+    const int SCInpLenThreshold = nctx * 0.6; //how big must the input array be to trigger smartcontext
+    const int SCPastLenThreshold = nctx * 0.5; //how wide of a gap between the fast forwarded past and the present to trigger smart context
+    const float SCTruncationRatio = 0.5; //ratio for how many tokens to fast forward
+
+    // printf("\nORIGINAL CTX:\n");
+    // print_tok_vec(current_context_tokens);
+    // printf("\nORIGINAL EMBD:\n");
+    // print_tok_vec(embd_inp);
+
+    //fast forward the past based on identical tokens, stop once a divergence is noted
+    int embd_inp_len = embd_inp.size();
+    for (int i = 0; i < current_context_tokens.size(); ++i)
+    {
+        if (current_context_tokens[i] == embd_inp[i])
+        {
+            n_past += 1;
+            last_n_tokens.push_back(current_context_tokens[i]);
+        }
+        else
+        {
+            break;
+        }
+        if ((i + 2) >= embd_inp_len)
+        {
+            break;
+        }
+    }
+
+    last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + n_past);
+    embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_past);
+    embd_inp_len = embd_inp.size();
+
+    //smart context mode, detect if we have a shifted context at max length
+    //requirement: previous context was at least nctx/2 longer than current,
+    //mode is on, and current context already maxed.
+
+    // printf("\nconds: %d %d %d\n",current_context_tokens.size() >= nctx*0.8
+    // ,embd_inp_len >= nctx*0.6 ,current_context_tokens.size() - n_past > nctx*0.5);
+    // printf("csiz:%d par:%d eilen:%d np:%d",current_context_tokens.size(), (int)(nctx*0.8),embd_inp_len,n_past);
+
+    if (useSmartContext && smartcontext.size() > 0 && embd_inp_len >= SCInpLenThreshold)
+    {
+        // printf("curfullcontext:\n");
+        // print_tok_vec(current_context_tokens);
+
+        //see if smartcontext is still usable
+        // printf("smartctx:\n");
+        // print_tok_vec(smartcontext);
+        // printf("embinp:\n");
+        // print_tok_vec(embd_inp);
+        auto shared = LongestCommonSubseq(smartcontext, embd_inp);
+        if (shared.size() > SCTokThreshold && ArrStartWith(smartcontext, shared)) //at least 32 tokens in common
+        {
+            int found = ArrFindIndexOf(embd_inp,shared);
+            if(found>=0)
+            {
+                auto trimmed = std::vector<int>(embd_inp.begin() + found, embd_inp.end());
+                embd_inp = trimmed;
+                embd_inp_len = embd_inp.size();
+                // printf("trimmed:\n");
+                // print_tok_vec(embd_inp,&vocab.id_to_token);
+                printf("\n[Reusing Smart Context: %d allowance remaining]", found);
+
+                int old_n_past = n_past;
+                int offset_fix = old_n_past;
+                if (current_context_tokens[n_past] != embd_inp[0])
+                {
+                    offset_fix = 0;
+                }
+
+                for (int i = n_past; i < current_context_tokens.size(); ++i)
+                {
+                    //printf("\n%s and %s\n",vocab.id_to_token[current_context_tokens[i]].c_str(), vocab.id_to_token[embd_inp[i-offset_fix]].c_str());
+                    if (current_context_tokens[i] == embd_inp[i-offset_fix])
+                    {
+                        n_past += 1;
+                        last_n_tokens.push_back(current_context_tokens[i]);
+                    }
+                    else
+                    {
+                        break;
+                    }
+                    if ((i + 2) >= embd_inp_len)
+                    {
+                        break;
+                    }
+                }
+
+                last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + (n_past-old_n_past));
+                embd_inp.erase(embd_inp.begin(), embd_inp.begin() + (n_past-old_n_past));
+                // printf("np:%d newembinp: \n",n_past);
+                // print_tok_vec(embd_inp);
+            }else{
+                smartcontext.clear();
+            }
+        }
+        else
+        {
+            smartcontext.clear();
+        }
+    }
+    else
+    {
+        smartcontext.clear();
+    }
+
+    if(useSmartContext
+    && smartcontext.size()==0 && current_context_tokens.size() >= SCCtxLenThreshold
+    && embd_inp_len >= SCInpLenThreshold
+    && current_context_tokens.size() - n_past > SCPastLenThreshold)
+    {
+        //determine longest common substring after removing start part
+        int shiftamt = embd_inp.size() * SCTruncationRatio;
+        smartcontext = std::vector<int>(embd_inp.begin() + shiftamt, embd_inp.end());
+        printf("\n[New Smart Context Triggered! Buffered Token Allowance: %d]",shiftamt);
+        // printf("smartctx:\n");
+        // print_tok_vec(smartcontext,&vocab.id_to_token);
+        embd_inp = smartcontext;
+        //if max ctx length is exceeded, chop the prompt in half after the start part, and memorize it. The memorized part becomes LCS marker.
+        //when a future prompt comes in, find the LCS again. If LCS > a length and LCS starts with memorized LCS
+        //remove all tokens between start part and start of LCS in new prompt, thus avoiding shift
+        //if LCS not found or mismatched, regenerate. chop new prompt and repeat from step B
+    }
+}
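To make the reuse branch easier to follow outside diff form, here is a minimal, self-contained C++ sketch of the same decision flow. This is not code from the commit: lcs_run is a simplified stand-in for LongestCommonSubseq, std::equal and std::search stand in for ArrStartWith and ArrFindIndexOf, the token IDs are toy values, and the SCTokThreshold/SCInpLenThreshold checks are omitted.

```cpp
// Hypothetical sketch of the SmartContext reuse decision (not commit code).
#include <algorithm>
#include <cstdio>
#include <vector>

// Longest common contiguous token run; the commit's LongestCommonSubseq
// computes the same thing with an O(m*n) suffix-table DP.
static std::vector<int> lcs_run(const std::vector<int> &x, const std::vector<int> &y)
{
    std::vector<int> best;
    for (size_t i = 0; i < x.size(); ++i)
        for (size_t j = 0; j < y.size(); ++j)
        {
            size_t k = 0;
            while (i + k < x.size() && j + k < y.size() && x[i + k] == y[j + k])
                ++k;
            if (k > best.size())
                best.assign(x.begin() + i, x.begin() + i + k);
        }
    return best;
}

int main()
{
    // The memorized marker: the back half of an earlier, truncated prompt.
    std::vector<int> smartcontext = {50, 51, 52, 53, 54, 55};
    // A later prompt: fresh lead-in tokens, then the shared tail continues.
    std::vector<int> embd_inp = {10, 11, 50, 51, 52, 53, 54, 55, 90, 91};

    std::vector<int> shared = lcs_run(smartcontext, embd_inp);

    // Real code also requires shared.size() > SCTokThreshold (32 tokens);
    // ArrStartWith: the match must be anchored at the marker's start.
    bool usable = !shared.empty() &&
                  std::equal(shared.begin(), shared.end(), smartcontext.begin());
    if (usable)
    {
        // ArrFindIndexOf: locate the shared run inside the new prompt...
        auto it = std::search(embd_inp.begin(), embd_inp.end(),
                              shared.begin(), shared.end());
        int found = (int)(it - embd_inp.begin());
        // ...and drop everything before it, so only tokens past the cached
        // context still need to be evaluated.
        embd_inp.erase(embd_inp.begin(), embd_inp.begin() + found);
        printf("[Reusing Smart Context: %d allowance remaining]\n", found);
    }
    return 0;
}
```

The real helper additionally clears the marker whenever the match is too short or not anchored at its start, so a fresh SmartContext is memorized on the next overflow.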
@@ -44,5 +44,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_outputs &output);
 void timer_start();
 double timer_check();
 void print_tok_vec(std::vector<int> &embd);
+void print_tok_vec(std::vector<int> &embd, std::map<int32_t, std::string> * decoder);
 void print_tok_vec(std::vector<float> &embd);
-FileFormat check_file_format(const std::string & fname);
+std::vector<int> LongestCommonSubseq(const std::vector<int> x, const std::vector<int> y);
+bool ArrStartWith(const std::vector<int> targetArray, const std::vector<int> searchSeq);
+int ArrFindIndexOf(const std::vector<int> targetArray, const std::vector<int> searchSeq);
+
+FileFormat check_file_format(const std::string & fname);
+void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
+    int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext, const bool useSmartContext);