compability with basic_api, change api path to /extra
This commit is contained in:
parent
b4e9e185d3
commit
dee692a63e
1 changed files with 26 additions and 10 deletions
32
koboldcpp.py
32
koboldcpp.py
|
@ -226,7 +226,7 @@ class ServerRequestHandler:
|
|||
self.embedded_kailite = embedded_kailite
|
||||
|
||||
|
||||
async def generate_text(self, newprompt, genparams):
|
||||
async def generate_text(self, newprompt, genparams, basic_api_flag):
|
||||
loop = asyncio.get_event_loop()
|
||||
executor = ThreadPoolExecutor()
|
||||
|
||||
|
@ -234,6 +234,22 @@ class ServerRequestHandler:
|
|||
# Reset finished status before generating
|
||||
handle.bind_set_stream_finished(False)
|
||||
|
||||
if basic_api_flag:
|
||||
return generate(
|
||||
prompt=newprompt,
|
||||
max_length=genparams.get('max', 50),
|
||||
temperature=genparams.get('temperature', 0.8),
|
||||
top_k=int(genparams.get('top_k', 120)),
|
||||
top_a=genparams.get('top_a', 0.0),
|
||||
top_p=genparams.get('top_p', 0.85),
|
||||
typical_p=genparams.get('typical', 1.0),
|
||||
tfs=genparams.get('tfs', 1.0),
|
||||
rep_pen=genparams.get('rep_pen', 1.1),
|
||||
rep_pen_range=genparams.get('rep_pen_range', 128),
|
||||
seed=genparams.get('sampler_seed', -1),
|
||||
stop_sequence=genparams.get('stop_sequence', [])
|
||||
)
|
||||
|
||||
return generate(prompt=newprompt,
|
||||
max_context_length=genparams.get('max_context_length', maxctx),
|
||||
max_length=genparams.get('max_length', 50),
|
||||
|
@ -251,7 +267,9 @@ class ServerRequestHandler:
|
|||
|
||||
recvtxt = await loop.run_in_executor(executor, run_blocking)
|
||||
|
||||
res = {"results": [{"text": recvtxt}]}
|
||||
utfprint("\nOutput: " + recvtxt)
|
||||
|
||||
res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]}
|
||||
|
||||
try:
|
||||
return res
|
||||
|
@ -279,18 +297,17 @@ class ServerRequestHandler:
|
|||
await response.write_eof()
|
||||
await response.force_close()
|
||||
|
||||
async def handle_request(self, request, genparams, newprompt, stream_flag):
|
||||
async def handle_request(self, request, genparams, newprompt, basic_api_flag, stream_flag):
|
||||
tasks = []
|
||||
|
||||
if stream_flag:
|
||||
tasks.append(self.handle_sse_stream(request,))
|
||||
|
||||
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams))
|
||||
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
|
||||
tasks.append(generate_task)
|
||||
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
if not stream_flag:
|
||||
generate_result = generate_task.result()
|
||||
return generate_result
|
||||
except Exception as e:
|
||||
|
@ -344,7 +361,6 @@ class ServerRequestHandler:
|
|||
kai_api_flag = False
|
||||
kai_sse_stream_flag = False
|
||||
path = request.path.rstrip('/')
|
||||
print(request)
|
||||
|
||||
if modelbusy:
|
||||
return web.json_response(
|
||||
|
@ -358,7 +374,7 @@ class ServerRequestHandler:
|
|||
if path.endswith(('/api/v1/generate', '/api/latest/generate')):
|
||||
kai_api_flag = True
|
||||
|
||||
if path.endswith('/api/v1/generate/stream'):
|
||||
if path.endswith('/api/extra/generate/stream'):
|
||||
kai_api_flag = True
|
||||
kai_sse_stream_flag = True
|
||||
|
||||
|
@ -378,7 +394,7 @@ class ServerRequestHandler:
|
|||
fullprompt = genparams.get('text', "")
|
||||
newprompt = fullprompt
|
||||
|
||||
gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag)
|
||||
gen = await self.handle_request(request, genparams, newprompt, basic_api_flag, kai_sse_stream_flag)
|
||||
|
||||
modelbusy = False
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue