compability with basic_api, change api path to /extra

This commit is contained in:
SammCheese 2023-06-08 15:56:25 +02:00
parent b4e9e185d3
commit dee692a63e
No known key found for this signature in database
GPG key ID: 28CFE2321A140BA1

View file

@ -226,7 +226,7 @@ class ServerRequestHandler:
self.embedded_kailite = embedded_kailite
async def generate_text(self, newprompt, genparams):
async def generate_text(self, newprompt, genparams, basic_api_flag):
loop = asyncio.get_event_loop()
executor = ThreadPoolExecutor()
@ -234,6 +234,22 @@ class ServerRequestHandler:
# Reset finished status before generating
handle.bind_set_stream_finished(False)
if basic_api_flag:
return generate(
prompt=newprompt,
max_length=genparams.get('max', 50),
temperature=genparams.get('temperature', 0.8),
top_k=int(genparams.get('top_k', 120)),
top_a=genparams.get('top_a', 0.0),
top_p=genparams.get('top_p', 0.85),
typical_p=genparams.get('typical', 1.0),
tfs=genparams.get('tfs', 1.0),
rep_pen=genparams.get('rep_pen', 1.1),
rep_pen_range=genparams.get('rep_pen_range', 128),
seed=genparams.get('sampler_seed', -1),
stop_sequence=genparams.get('stop_sequence', [])
)
return generate(prompt=newprompt,
max_context_length=genparams.get('max_context_length', maxctx),
max_length=genparams.get('max_length', 50),
@ -251,7 +267,9 @@ class ServerRequestHandler:
recvtxt = await loop.run_in_executor(executor, run_blocking)
res = {"results": [{"text": recvtxt}]}
utfprint("\nOutput: " + recvtxt)
res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]}
try:
return res
@ -279,20 +297,19 @@ class ServerRequestHandler:
await response.write_eof()
await response.force_close()
async def handle_request(self, request, genparams, newprompt, stream_flag):
async def handle_request(self, request, genparams, newprompt, basic_api_flag, stream_flag):
tasks = []
if stream_flag:
tasks.append(self.handle_sse_stream(request,))
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams))
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
tasks.append(generate_task)
try:
await asyncio.gather(*tasks)
if not stream_flag:
generate_result = generate_task.result()
return generate_result
generate_result = generate_task.result()
return generate_result
except Exception as e:
print(e)
@ -344,7 +361,6 @@ class ServerRequestHandler:
kai_api_flag = False
kai_sse_stream_flag = False
path = request.path.rstrip('/')
print(request)
if modelbusy:
return web.json_response(
@ -358,7 +374,7 @@ class ServerRequestHandler:
if path.endswith(('/api/v1/generate', '/api/latest/generate')):
kai_api_flag = True
if path.endswith('/api/v1/generate/stream'):
if path.endswith('/api/extra/generate/stream'):
kai_api_flag = True
kai_sse_stream_flag = True
@ -378,7 +394,7 @@ class ServerRequestHandler:
fullprompt = genparams.get('text', "")
newprompt = fullprompt
gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag)
gen = await self.handle_request(request, genparams, newprompt, basic_api_flag, kai_sse_stream_flag)
modelbusy = False