diff --git a/koboldcpp.py b/koboldcpp.py index bfb534c13..7320b2ce0 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -226,7 +226,7 @@ class ServerRequestHandler: self.embedded_kailite = embedded_kailite - async def generate_text(self, newprompt, genparams): + async def generate_text(self, newprompt, genparams, basic_api_flag): loop = asyncio.get_event_loop() executor = ThreadPoolExecutor() @@ -234,6 +234,22 @@ class ServerRequestHandler: # Reset finished status before generating handle.bind_set_stream_finished(False) + if basic_api_flag: + return generate( + prompt=newprompt, + max_length=genparams.get('max', 50), + temperature=genparams.get('temperature', 0.8), + top_k=int(genparams.get('top_k', 120)), + top_a=genparams.get('top_a', 0.0), + top_p=genparams.get('top_p', 0.85), + typical_p=genparams.get('typical', 1.0), + tfs=genparams.get('tfs', 1.0), + rep_pen=genparams.get('rep_pen', 1.1), + rep_pen_range=genparams.get('rep_pen_range', 128), + seed=genparams.get('sampler_seed', -1), + stop_sequence=genparams.get('stop_sequence', []) + ) + return generate(prompt=newprompt, max_context_length=genparams.get('max_context_length', maxctx), max_length=genparams.get('max_length', 50), @@ -251,7 +267,9 @@ class ServerRequestHandler: recvtxt = await loop.run_in_executor(executor, run_blocking) - res = {"results": [{"text": recvtxt}]} + utfprint("\nOutput: " + recvtxt) + + res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]} try: return res @@ -279,20 +297,19 @@ class ServerRequestHandler: await response.write_eof() await response.force_close() - async def handle_request(self, request, genparams, newprompt, stream_flag): + async def handle_request(self, request, genparams, newprompt, basic_api_flag, stream_flag): tasks = [] if stream_flag: tasks.append(self.handle_sse_stream(request,)) - generate_task = asyncio.create_task(self.generate_text(newprompt, genparams)) + generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag)) tasks.append(generate_task) try: await asyncio.gather(*tasks) - if not stream_flag: - generate_result = generate_task.result() - return generate_result + generate_result = generate_task.result() + return generate_result except Exception as e: print(e) @@ -344,7 +361,6 @@ class ServerRequestHandler: kai_api_flag = False kai_sse_stream_flag = False path = request.path.rstrip('/') - print(request) if modelbusy: return web.json_response( @@ -358,7 +374,7 @@ class ServerRequestHandler: if path.endswith(('/api/v1/generate', '/api/latest/generate')): kai_api_flag = True - if path.endswith('/api/v1/generate/stream'): + if path.endswith('/api/extra/generate/stream'): kai_api_flag = True kai_sse_stream_flag = True @@ -378,7 +394,7 @@ class ServerRequestHandler: fullprompt = genparams.get('text', "") newprompt = fullprompt - gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag) + gen = await self.handle_request(request, genparams, newprompt, basic_api_flag, kai_sse_stream_flag) modelbusy = False