Merge remote-tracking branch 'sammcheese/sammcheese/tokenstreaming' into concedo_experimental

Concedo 2023-06-09 20:41:02 +08:00
commit b92f9fe3a2
4 changed files with 159 additions and 48 deletions

Changed file 1 of 4:

@@ -23,6 +23,10 @@
 std::string executable_path = "";
 std::string lora_filename = "";
+bool generation_finished;
+std::vector<std::string> generated_tokens;

 extern "C"
 {
@@ -207,4 +211,18 @@ extern "C"
     {
         return gpttype_generate(inputs, output);
     }
+    const char* new_token(int idx) {
+        if (generated_tokens.size() <= idx || idx < 0) return nullptr;
+        return generated_tokens[idx].c_str();
+    }
+    int get_stream_count() {
+        return generated_tokens.size();
+    }
+    bool has_finished() {
+        return generation_finished;
+    }
 }
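
The three exports above form a simple polling interface: gpttype_generate appends each detokenized piece to generated_tokens, new_token(idx) hands back the piece at a given index (or a null pointer if it is not there yet), get_stream_count() reports how many pieces exist so far, and has_finished() flips once generation returns. As a rough illustration only (not part of this commit), a ctypes caller could drain the stream like this; the library path and the polling interval are assumptions:

# Illustrative sketch, not from the commit: polling the new exports via ctypes.
# Assumes the shared library path below and that generate() is already running
# on another thread (as koboldcpp.py arranges).
import ctypes
import time

handle = ctypes.CDLL("./koboldcpp.so")   # assumed build artifact / path
handle.new_token.restype = ctypes.c_char_p
handle.new_token.argtypes = [ctypes.c_int]
handle.get_stream_count.restype = ctypes.c_int
handle.has_finished.restype = ctypes.c_bool

def drain_tokens():
    seen = 0
    while True:
        # Read any tokens published since the last poll, in order.
        while seen < handle.get_stream_count():
            tok = handle.new_token(seen)   # NULL comes back as None
            if tok is None:
                break
            print(tok.decode("utf-8"), end="", flush=True)
            seen += 1
        if handle.has_finished():
            break
        time.sleep(0.05)   # assumed poll interval to avoid a busy loop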

Changed file 2 of 4:

@@ -18,6 +18,7 @@ struct load_model_inputs
     const int clblast_info = 0;
     const int blasbatchsize = 512;
     const bool debugmode;
+    const bool stream_sse;
     const int forceversion = 0;
     const int gpulayers = 0;
 };
@@ -48,3 +49,6 @@ struct generation_outputs
 extern std::string executable_path;
 extern std::string lora_filename;
+extern std::vector<std::string> generated_tokens;
+extern bool generation_finished;

Changed file 3 of 4:

@@ -63,6 +63,7 @@ static bool useSmartContext = false;
 static bool unbanTokens = false;
 static int blasbatchsize = 512;
 static bool debugmode = false;
+static bool stream_sse = true;
 static std::string modelname;
 static std::vector<gpt_vocab::id> last_n_tokens;
 static std::vector<gpt_vocab::id> current_context_tokens;
@@ -735,6 +736,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     params.n_batch = n_batch;
     params.n_threads = n_threads;
+    generation_finished = false; // Set current generation status
+    generated_tokens.clear(); // New Generation, new tokens

     if (params.repeat_last_n < 1)
     {
         params.repeat_last_n = 1;
@@ -1038,6 +1042,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             fprintf(stderr, "Failed to predict\n");
             snprintf(output.text, sizeof(output.text), "%s", "");
             output.status = 0;
+            generation_finished = true;
             return output;
         }
     }
@@ -1147,7 +1152,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         for (auto id : embd)
         {
-            concat_output += FileFormatTokenizeID(id,file_format);
+            std::string tokenizedstr = FileFormatTokenizeID(id, file_format);
+            if (stream_sse)
+            {
+                generated_tokens.push_back(tokenizedstr);
+            }
+            concat_output += tokenizedstr;
         }

         if (startedsampling)
@@ -1214,6 +1225,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         printf("\nTime Taken - Processing:%.1fs (%.0fms/T), Generation:%.1fs (%.0fms/T), Total:%.1fs", time1, pt1, time2, pt2, (time1 + time2));
         fflush(stdout);
         output.status = 1;
+        generation_finished = true;
         snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
         return output;

Changed file 4 of 4:

@@ -5,7 +5,8 @@
 import ctypes
 import os
 import argparse
-import json, http.server, threading, socket, sys, time
+import json, sys, http.server, time, asyncio, socket, threading
+from concurrent.futures import ThreadPoolExecutor

 stop_token_max = 10
@@ -134,6 +135,10 @@ def init_library():
     handle.load_model.restype = ctypes.c_bool
     handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
     handle.generate.restype = generation_outputs
+    handle.new_token.restype = ctypes.c_char_p
+    handle.new_token.argtypes = [ctypes.c_int]
+    handle.get_stream_count.restype = ctypes.c_int
+    handle.has_finished.restype = ctypes.c_bool

 def load_model(model_filename):
     inputs = load_model_inputs()
@@ -183,7 +188,7 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
     else:
         inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
     inputs.seed = seed
-    for n in range(0,stop_token_max):
+    for n in range(stop_token_max):
         if not stop_sequence or n >= len(stop_sequence):
             inputs.stop_sequence[n] = "".encode("UTF-8")
         else:
@@ -224,8 +229,106 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     def __call__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

+    async def generate_text(self, newprompt, genparams, basic_api_flag):
+        loop = asyncio.get_event_loop()
+        executor = ThreadPoolExecutor()
+
+        def run_blocking():
+            if basic_api_flag:
+                return generate(
+                    prompt=newprompt,
+                    max_length=genparams.get('max', 50),
+                    temperature=genparams.get('temperature', 0.8),
+                    top_k=int(genparams.get('top_k', 120)),
+                    top_a=genparams.get('top_a', 0.0),
+                    top_p=genparams.get('top_p', 0.85),
+                    typical_p=genparams.get('typical', 1.0),
+                    tfs=genparams.get('tfs', 1.0),
+                    rep_pen=genparams.get('rep_pen', 1.1),
+                    rep_pen_range=genparams.get('rep_pen_range', 128),
+                    seed=genparams.get('sampler_seed', -1),
+                    stop_sequence=genparams.get('stop_sequence', [])
+                )
+            else:
+                return generate(prompt=newprompt,
+                    max_context_length=genparams.get('max_context_length', maxctx),
+                    max_length=genparams.get('max_length', 50),
+                    temperature=genparams.get('temperature', 0.8),
+                    top_k=genparams.get('top_k', 120),
+                    top_a=genparams.get('top_a', 0.0),
+                    top_p=genparams.get('top_p', 0.85),
+                    typical_p=genparams.get('typical', 1.0),
+                    tfs=genparams.get('tfs', 1.0),
+                    rep_pen=genparams.get('rep_pen', 1.1),
+                    rep_pen_range=genparams.get('rep_pen_range', 128),
+                    seed=genparams.get('sampler_seed', -1),
+                    stop_sequence=genparams.get('stop_sequence', [])
+                )
+
+        recvtxt = await loop.run_in_executor(executor, run_blocking)
+        utfprint("\nOutput: " + recvtxt)
+        res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]}
+        try:
+            return res
+        except Exception as e:
+            print(f"Generate: Error while generating: {e}")
+
+    async def send_sse_event(self, event, data):
+        self.wfile.write(f'event: {event}\n'.encode())
+        self.wfile.write(f'data: {data}\n\n'.encode())
+
+    async def handle_sse_stream(self):
+        self.send_response(200)
+        self.send_header("Content-Type", "text/event-stream")
+        self.send_header("Cache-Control", "no-cache")
+        self.send_header("Connection", "keep-alive")
+        self.end_headers()
+
+        current_token = 0
+        while not handle.has_finished():
+            if current_token < handle.get_stream_count():
+                token = handle.new_token(current_token)
+                if token is None: # Token isn't ready yet, received a null pointer
+                    continue
+                current_token += 1
+
+                tokenStr = ctypes.string_at(token).decode('utf-8')
+                event_data = {"token": tokenStr}
+                event_str = json.dumps(event_data)
+                await self.send_sse_event("message", event_str)
+
+            await asyncio.sleep(0)
+
+        # Implement connection closing here
+
+    async def handle_request(self, genparams, newprompt, basic_api_flag, stream_flag):
+        tasks = []
+
+        if stream_flag:
+            tasks.append(self.handle_sse_stream())
+
+        generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
+        tasks.append(generate_task)
+
+        try:
+            await asyncio.gather(*tasks)
+            generate_result = generate_task.result()
+            return generate_result
+        except Exception as e:
+            print(e)
+
     def do_GET(self):
-        global maxctx, maxlen, friendlymodelname, KcppVersion
+        global maxctx, maxlen, friendlymodelname, KcppVersion, streamLock
         self.path = self.path.rstrip('/')
         response_body = None
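
The handlers added above follow one pattern: the blocking generate() call is pushed onto a ThreadPoolExecutor via run_in_executor, while handle_sse_stream stays on the event loop polling the C API and writing SSE frames, and asyncio.gather drives both until generation finishes. A stripped-down sketch of that pattern, illustrative only (names and timings below are not from the commit):

# Minimal sketch of the pattern: a blocking producer in a worker thread,
# an async poller streaming its progress, both awaited together.
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

progress = []          # stands in for generated_tokens
done = False           # stands in for generation_finished

def blocking_generate():
    global done
    for i in range(5):
        time.sleep(0.2)                   # stands in for the blocking generate() call
        progress.append(f"tok{i} ")
    done = True
    return "".join(progress)

async def poller():
    sent = 0
    while not done or sent < len(progress):
        while sent < len(progress):
            print(progress[sent], end="", flush=True)   # stands in for an SSE write
            sent += 1
        await asyncio.sleep(0.05)         # yield so the event loop stays responsive

async def main():
    loop = asyncio.get_running_loop()
    result, _ = await asyncio.gather(
        loop.run_in_executor(ThreadPoolExecutor(), blocking_generate),
        poller(),
    )
    print("\nfinal:", result)

asyncio.run(main())
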
@@ -286,8 +389,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         body = self.rfile.read(content_length)
         basic_api_flag = False
         kai_api_flag = False
+        kai_sse_stream_flag = False
         self.path = self.path.rstrip('/')

         if modelbusy:
             self.send_response(503)
             self.end_headers()
@@ -303,72 +408,44 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
             kai_api_flag = True

+        if self.path.endswith('/api/extra/generate/stream'):
+            kai_api_flag = True
+            kai_sse_stream_flag = True
+
         if basic_api_flag or kai_api_flag:
             genparams = None
             try:
                 genparams = json.loads(body)
             except ValueError as e:
-                self.send_response(503)
-                self.end_headers()
-                return
+                return self.send_response(503)

             utfprint("\nInput: " + json.dumps(genparams))
             modelbusy = True

             if kai_api_flag:
                 fullprompt = genparams.get('prompt', "")
             else:
                 fullprompt = genparams.get('text', "")
             newprompt = fullprompt

-            recvtxt = ""
-            res = {}
-            if kai_api_flag:
-                recvtxt = generate(
-                    prompt=newprompt,
-                    max_context_length=genparams.get('max_context_length', maxctx),
-                    max_length=genparams.get('max_length', 50),
-                    temperature=genparams.get('temperature', 0.8),
-                    top_k=int(genparams.get('top_k', 120)),
-                    top_a=genparams.get('top_a', 0.0),
-                    top_p=genparams.get('top_p', 0.85),
-                    typical_p=genparams.get('typical', 1.0),
-                    tfs=genparams.get('tfs', 1.0),
-                    rep_pen=genparams.get('rep_pen', 1.1),
-                    rep_pen_range=genparams.get('rep_pen_range', 128),
-                    seed=genparams.get('sampler_seed', -1),
-                    stop_sequence=genparams.get('stop_sequence', [])
-                )
-                utfprint("\nOutput: " + recvtxt)
-                res = {"results": [{"text": recvtxt}]}
-            else:
-                recvtxt = generate(
-                    prompt=newprompt,
-                    max_length=genparams.get('max', 50),
-                    temperature=genparams.get('temperature', 0.8),
-                    top_k=int(genparams.get('top_k', 120)),
-                    top_a=genparams.get('top_a', 0.0),
-                    top_p=genparams.get('top_p', 0.85),
-                    typical_p=genparams.get('typical', 1.0),
-                    tfs=genparams.get('tfs', 1.0),
-                    rep_pen=genparams.get('rep_pen', 1.1),
-                    rep_pen_range=genparams.get('rep_pen_range', 128),
-                    seed=genparams.get('sampler_seed', -1),
-                    stop_sequence=genparams.get('stop_sequence', [])
-                )
-                utfprint("\nOutput: " + recvtxt)
-                res = {"data": {"seqs":[recvtxt]}}
+            gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))

             try:
                 self.send_response(200)
                 self.end_headers()
-                self.wfile.write(json.dumps(res).encode())
+                self.wfile.write(json.dumps(gen).encode())
             except:
                 print("Generate: The response could not be sent, maybe connection was terminated?")

             modelbusy = False
             return

         self.send_response(404)
         self.end_headers()

     def do_OPTIONS(self):
         self.send_response(200)
         self.end_headers()
@@ -382,10 +459,11 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         self.send_header('Access-Control-Allow-Methods', '*')
         self.send_header('Access-Control-Allow-Headers', '*')
         if "/api" in self.path:
+            if self.path.endswith("/stream"):
+                self.send_header('Content-type', 'text/event-stream')
             self.send_header('Content-type', 'application/json')
         else:
             self.send_header('Content-type', 'text/html')

         return super(ServerRequestHandler, self).end_headers()
@@ -500,7 +578,6 @@ def show_gui():
     unbantokens = tk.IntVar()
     highpriority = tk.IntVar()
     disablemmap = tk.IntVar()
-
     frameD = tk.Frame(root)
     tk.Checkbutton(frameD, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0)
     tk.Checkbutton(frameD, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1)
@@ -688,7 +765,7 @@ def main(args):
         except:
             print("--launch was set, but could not launch web browser automatically.")
         print(f"Please connect to custom endpoint at {epurl}")
-    RunServerMultiThreaded(args.host, args.port, embedded_kailite)
+    asyncio.run(RunServerMultiThreaded(args.host, args.port, embedded_kailite))

 if __name__ == '__main__':
     print("Welcome to KoboldCpp - Version " + KcppVersion) # just update version manually