diff --git a/koboldcpp.py b/koboldcpp.py index e8c5836cb..789d81a40 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -381,54 +381,40 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): super().log_message(format, *args) pass - async def generate_text(self, newprompt, genparams, basic_api_flag, stream_flag): + async def generate_text(self, genparams, api_format, stream_flag): def run_blocking(): - if basic_api_flag: - return generate( - prompt=newprompt, - max_length=genparams.get('max', 50), - temperature=genparams.get('temperature', 0.8), - top_k=int(genparams.get('top_k', 120)), - top_a=genparams.get('top_a', 0.0), - top_p=genparams.get('top_p', 0.85), - typical_p=genparams.get('typical', 1.0), - tfs=genparams.get('tfs', 1.0), - rep_pen=genparams.get('rep_pen', 1.1), - rep_pen_range=genparams.get('rep_pen_range', 128), - mirostat=genparams.get('mirostat', 0), - mirostat_tau=genparams.get('mirostat_tau', 5.0), - mirostat_eta=genparams.get('mirostat_eta', 0.1), - sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]), - seed=genparams.get('sampler_seed', -1), - stop_sequence=genparams.get('stop_sequence', []), - use_default_badwordsids=genparams.get('use_default_badwordsids', True), - stream_sse=stream_flag, - grammar=genparams.get('grammar', ''), - genkey=genparams.get('genkey', '')) + if api_format==1: + genparams["prompt"] = genparams.get('text', "") + genparams["top_k"] = int(genparams.get('top_k', 120)) + genparams["max_length"]=genparams.get('max', 50) + elif api_format==3: + scaled_rep_pen = genparams.get('presence_penalty', 0.1) + 1 + genparams["max_length"] = genparams.get('max_tokens', 50) + genparams["rep_pen"] = scaled_rep_pen - else: - return generate(prompt=newprompt, - max_context_length=genparams.get('max_context_length', maxctx), - max_length=genparams.get('max_length', 50), - temperature=genparams.get('temperature', 0.8), - top_k=genparams.get('top_k', 120), - top_a=genparams.get('top_a', 0.0), - top_p=genparams.get('top_p', 0.85), - typical_p=genparams.get('typical', 1.0), - tfs=genparams.get('tfs', 1.0), - rep_pen=genparams.get('rep_pen', 1.1), - rep_pen_range=genparams.get('rep_pen_range', 128), - mirostat=genparams.get('mirostat', 0), - mirostat_tau=genparams.get('mirostat_tau', 5.0), - mirostat_eta=genparams.get('mirostat_eta', 0.1), - sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]), - seed=genparams.get('sampler_seed', -1), - stop_sequence=genparams.get('stop_sequence', []), - use_default_badwordsids=genparams.get('use_default_badwordsids', True), - stream_sse=stream_flag, - grammar=genparams.get('grammar', ''), - genkey=genparams.get('genkey', '')) + return generate( + prompt=genparams.get('prompt', ""), + max_context_length=genparams.get('max_context_length', maxctx), + max_length=genparams.get('max_length', 80), + temperature=genparams.get('temperature', 0.8), + top_k=genparams.get('top_k', 120), + top_a=genparams.get('top_a', 0.0), + top_p=genparams.get('top_p', 0.85), + typical_p=genparams.get('typical', 1.0), + tfs=genparams.get('tfs', 1.0), + rep_pen=genparams.get('rep_pen', 1.1), + rep_pen_range=genparams.get('rep_pen_range', 256), + mirostat=genparams.get('mirostat', 0), + mirostat_tau=genparams.get('mirostat_tau', 5.0), + mirostat_eta=genparams.get('mirostat_eta', 0.1), + sampler_order=genparams.get('sampler_order', [6,0,1,3,4,2,5]), + seed=genparams.get('sampler_seed', -1), + stop_sequence=genparams.get('stop_sequence', []), + use_default_badwordsids=genparams.get('use_default_badwordsids', True), + stream_sse=stream_flag, + grammar=genparams.get('grammar', ''), + genkey=genparams.get('genkey', '')) recvtxt = "" if stream_flag: @@ -441,7 +427,13 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): if args.debugmode!=-1: utfprint("\nOutput: " + recvtxt) - res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]} + if api_format==1: + res = {"data": {"seqs":[recvtxt]}} + elif api_format==3: + res = {"id": "cmpl-1", "object": "text_completion", "created": 1, "model": "koboldcpp", + "choices": [{"text": recvtxt, "index": 0, "finish_reason": "length"}]} + else: + res = {"results": [{"text": recvtxt}]} try: return res @@ -490,13 +482,13 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): self.close_connection = True - async def handle_request(self, genparams, newprompt, basic_api_flag, stream_flag): + async def handle_request(self, genparams, api_format, stream_flag): tasks = [] if stream_flag: tasks.append(self.handle_sse_stream()) - generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag, stream_flag)) + generate_task = asyncio.create_task(self.generate_text(genparams, api_format, stream_flag)) tasks.append(generate_task) try: @@ -568,6 +560,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore") response_body = (json.dumps({"results": [{"text": pendtxtStr}]}).encode()) + elif self.path.endswith('/api/extra/oai/v1/models'): + response_body = (json.dumps({"object":"list","data":[{"id":"koboldcpp","object":"model","created":1,"owned_by":"koboldcpp","permission":[],"root":"koboldcpp"}]}).encode()) elif self.path.endswith(('/api')) or self.path.endswith(('/api/v1')): response_body = (json.dumps({"result":"KoboldCpp partial API reference can be found at https://link.concedo.workers.dev/koboldapi"}).encode()) @@ -652,20 +646,24 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): requestsinqueue = (requestsinqueue - 1) if requestsinqueue>0 else 0 try: - basic_api_flag = False - kai_api_flag = False kai_sse_stream_flag = False + + api_format = 0 #1=basic,2=kai,3=oai + if self.path.endswith('/request'): - basic_api_flag = True + api_format = 1 if self.path.endswith(('/api/v1/generate', '/api/latest/generate')): - kai_api_flag = True + api_format = 2 if self.path.endswith('/api/extra/generate/stream'): - kai_api_flag = True + api_format = 2 kai_sse_stream_flag = True - if basic_api_flag or kai_api_flag: + if self.path.endswith('/api/extra/oai/v1/completions'): + api_format = 3 + + if api_format>0: genparams = None try: genparams = json.loads(body) @@ -676,13 +674,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): if args.debugmode!=-1: utfprint("\nInput: " + json.dumps(genparams)) - if kai_api_flag: - fullprompt = genparams.get('prompt', "") - else: - fullprompt = genparams.get('text', "") - newprompt = fullprompt - - gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag)) + gen = asyncio.run(self.handle_request(genparams, api_format, kai_sse_stream_flag)) try: # Headers are already sent when streaming