Compatibility with basic_api; change API path to /extra
This commit is contained in:
parent
b4e9e185d3
commit
dee692a63e
1 changed files with 26 additions and 10 deletions
36
koboldcpp.py
36
koboldcpp.py
|
@ -226,7 +226,7 @@ class ServerRequestHandler:
|
||||||
self.embedded_kailite = embedded_kailite
|
self.embedded_kailite = embedded_kailite
|
||||||
|
|
||||||
|
|
||||||
async def generate_text(self, newprompt, genparams):
|
async def generate_text(self, newprompt, genparams, basic_api_flag):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
executor = ThreadPoolExecutor()
|
executor = ThreadPoolExecutor()
|
||||||
|
|
||||||
|
@ -234,6 +234,22 @@ class ServerRequestHandler:
|
||||||
# Reset finished status before generating
|
# Reset finished status before generating
|
||||||
handle.bind_set_stream_finished(False)
|
handle.bind_set_stream_finished(False)
|
||||||
|
|
||||||
|
if basic_api_flag:
|
||||||
|
return generate(
|
||||||
|
prompt=newprompt,
|
||||||
|
max_length=genparams.get('max', 50),
|
||||||
|
temperature=genparams.get('temperature', 0.8),
|
||||||
|
top_k=int(genparams.get('top_k', 120)),
|
||||||
|
top_a=genparams.get('top_a', 0.0),
|
||||||
|
top_p=genparams.get('top_p', 0.85),
|
||||||
|
typical_p=genparams.get('typical', 1.0),
|
||||||
|
tfs=genparams.get('tfs', 1.0),
|
||||||
|
rep_pen=genparams.get('rep_pen', 1.1),
|
||||||
|
rep_pen_range=genparams.get('rep_pen_range', 128),
|
||||||
|
seed=genparams.get('sampler_seed', -1),
|
||||||
|
stop_sequence=genparams.get('stop_sequence', [])
|
||||||
|
)
|
||||||
|
|
||||||
return generate(prompt=newprompt,
|
return generate(prompt=newprompt,
|
||||||
max_context_length=genparams.get('max_context_length', maxctx),
|
max_context_length=genparams.get('max_context_length', maxctx),
|
||||||
max_length=genparams.get('max_length', 50),
|
max_length=genparams.get('max_length', 50),
|
||||||
|
@ -251,7 +267,9 @@ class ServerRequestHandler:
|
||||||
|
|
||||||
recvtxt = await loop.run_in_executor(executor, run_blocking)
|
recvtxt = await loop.run_in_executor(executor, run_blocking)
|
||||||
|
|
||||||
res = {"results": [{"text": recvtxt}]}
|
utfprint("\nOutput: " + recvtxt)
|
||||||
|
|
||||||
|
res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return res
|
return res
|
||||||
|
@ -279,20 +297,19 @@ class ServerRequestHandler:
|
||||||
await response.write_eof()
|
await response.write_eof()
|
||||||
await response.force_close()
|
await response.force_close()
|
||||||
|
|
||||||
async def handle_request(self, request, genparams, newprompt, stream_flag):
|
async def handle_request(self, request, genparams, newprompt, basic_api_flag, stream_flag):
|
||||||
tasks = []
|
tasks = []
|
||||||
|
|
||||||
if stream_flag:
|
if stream_flag:
|
||||||
tasks.append(self.handle_sse_stream(request,))
|
tasks.append(self.handle_sse_stream(request,))
|
||||||
|
|
||||||
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams))
|
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
|
||||||
tasks.append(generate_task)
|
tasks.append(generate_task)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
if not stream_flag:
|
generate_result = generate_task.result()
|
||||||
generate_result = generate_task.result()
|
return generate_result
|
||||||
return generate_result
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
|
@ -344,7 +361,6 @@ class ServerRequestHandler:
|
||||||
kai_api_flag = False
|
kai_api_flag = False
|
||||||
kai_sse_stream_flag = False
|
kai_sse_stream_flag = False
|
||||||
path = request.path.rstrip('/')
|
path = request.path.rstrip('/')
|
||||||
print(request)
|
|
||||||
|
|
||||||
if modelbusy:
|
if modelbusy:
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
|
@ -358,7 +374,7 @@ class ServerRequestHandler:
|
||||||
if path.endswith(('/api/v1/generate', '/api/latest/generate')):
|
if path.endswith(('/api/v1/generate', '/api/latest/generate')):
|
||||||
kai_api_flag = True
|
kai_api_flag = True
|
||||||
|
|
||||||
if path.endswith('/api/v1/generate/stream'):
|
if path.endswith('/api/extra/generate/stream'):
|
||||||
kai_api_flag = True
|
kai_api_flag = True
|
||||||
kai_sse_stream_flag = True
|
kai_sse_stream_flag = True
|
||||||
|
|
||||||
|
@ -378,7 +394,7 @@ class ServerRequestHandler:
|
||||||
fullprompt = genparams.get('text', "")
|
fullprompt = genparams.get('text', "")
|
||||||
newprompt = fullprompt
|
newprompt = fullprompt
|
||||||
|
|
||||||
gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag)
|
gen = await self.handle_request(request, genparams, newprompt, basic_api_flag, kai_sse_stream_flag)
|
||||||
|
|
||||||
modelbusy = False
|
modelbusy = False
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue