compatibility with basic_api, change api path to /extra

This commit is contained in:
SammCheese 2023-06-08 15:56:25 +02:00
parent b4e9e185d3
commit dee692a63e
No known key found for this signature in database
GPG key ID: 28CFE2321A140BA1

View file

@ -226,7 +226,7 @@ class ServerRequestHandler:
self.embedded_kailite = embedded_kailite self.embedded_kailite = embedded_kailite
async def generate_text(self, newprompt, genparams): async def generate_text(self, newprompt, genparams, basic_api_flag):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
executor = ThreadPoolExecutor() executor = ThreadPoolExecutor()
@ -234,6 +234,22 @@ class ServerRequestHandler:
# Reset finished status before generating # Reset finished status before generating
handle.bind_set_stream_finished(False) handle.bind_set_stream_finished(False)
if basic_api_flag:
return generate(
prompt=newprompt,
max_length=genparams.get('max', 50),
temperature=genparams.get('temperature', 0.8),
top_k=int(genparams.get('top_k', 120)),
top_a=genparams.get('top_a', 0.0),
top_p=genparams.get('top_p', 0.85),
typical_p=genparams.get('typical', 1.0),
tfs=genparams.get('tfs', 1.0),
rep_pen=genparams.get('rep_pen', 1.1),
rep_pen_range=genparams.get('rep_pen_range', 128),
seed=genparams.get('sampler_seed', -1),
stop_sequence=genparams.get('stop_sequence', [])
)
return generate(prompt=newprompt, return generate(prompt=newprompt,
max_context_length=genparams.get('max_context_length', maxctx), max_context_length=genparams.get('max_context_length', maxctx),
max_length=genparams.get('max_length', 50), max_length=genparams.get('max_length', 50),
@ -251,7 +267,9 @@ class ServerRequestHandler:
recvtxt = await loop.run_in_executor(executor, run_blocking) recvtxt = await loop.run_in_executor(executor, run_blocking)
res = {"results": [{"text": recvtxt}]} utfprint("\nOutput: " + recvtxt)
res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]}
try: try:
return res return res
@ -279,20 +297,19 @@ class ServerRequestHandler:
await response.write_eof() await response.write_eof()
await response.force_close() await response.force_close()
async def handle_request(self, request, genparams, newprompt, stream_flag): async def handle_request(self, request, genparams, newprompt, basic_api_flag, stream_flag):
tasks = [] tasks = []
if stream_flag: if stream_flag:
tasks.append(self.handle_sse_stream(request,)) tasks.append(self.handle_sse_stream(request,))
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams)) generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
tasks.append(generate_task) tasks.append(generate_task)
try: try:
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
if not stream_flag: generate_result = generate_task.result()
generate_result = generate_task.result() return generate_result
return generate_result
except Exception as e: except Exception as e:
print(e) print(e)
@ -344,7 +361,6 @@ class ServerRequestHandler:
kai_api_flag = False kai_api_flag = False
kai_sse_stream_flag = False kai_sse_stream_flag = False
path = request.path.rstrip('/') path = request.path.rstrip('/')
print(request)
if modelbusy: if modelbusy:
return web.json_response( return web.json_response(
@ -358,7 +374,7 @@ class ServerRequestHandler:
if path.endswith(('/api/v1/generate', '/api/latest/generate')): if path.endswith(('/api/v1/generate', '/api/latest/generate')):
kai_api_flag = True kai_api_flag = True
if path.endswith('/api/v1/generate/stream'): if path.endswith('/api/extra/generate/stream'):
kai_api_flag = True kai_api_flag = True
kai_sse_stream_flag = True kai_sse_stream_flag = True
@ -378,7 +394,7 @@ class ServerRequestHandler:
fullprompt = genparams.get('text', "") fullprompt = genparams.get('text', "")
newprompt = fullprompt newprompt = fullprompt
gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag) gen = await self.handle_request(request, genparams, newprompt, basic_api_flag, kai_sse_stream_flag)
modelbusy = False modelbusy = False