From 97971291e9e5d05c428d4c0f0cb8f956e36a63c5 Mon Sep 17 00:00:00 2001 From: SammCheese Date: Wed, 7 Jun 2023 00:48:00 +0200 Subject: [PATCH 1/8] draft: token streaming --- expose.cpp | 32 +++++ expose.h | 4 + gpttype_adapter.cpp | 11 +- koboldcpp.py | 306 ++++++++++++++++++++------------------------ 4 files changed, 183 insertions(+), 170 deletions(-) diff --git a/expose.cpp b/expose.cpp index 6388bed88..283438a4b 100644 --- a/expose.cpp +++ b/expose.cpp @@ -23,6 +23,11 @@ std::string executable_path = ""; std::string lora_filename = ""; + +static std::string current_token = ""; +static bool new_token_available = false; +static bool finished_stream = false; + extern "C" { @@ -205,6 +210,33 @@ extern "C" generation_outputs generate(const generation_inputs inputs, generation_outputs &output) { + finished_stream = false; return gpttype_generate(inputs, output); } + + + const char* new_token() { + if (new_token_available) { + new_token_available = false; + return current_token.c_str(); + } + return nullptr; + } + + bool is_locked() { + return !new_token_available; + } + + bool has_finished() { + return finished_stream; + } +} + +void receive_current_token(std::string token) { + current_token = token; + new_token_available = true; +} + +void set_stream_finished() { + finished_stream = true; } diff --git a/expose.h b/expose.h index cb475f141..dac7fffe0 100644 --- a/expose.h +++ b/expose.h @@ -18,6 +18,7 @@ struct load_model_inputs const int clblast_info = 0; const int blasbatchsize = 512; const bool debugmode; + const bool stream_sse; const int forceversion = 0; const int gpulayers = 0; }; @@ -48,3 +49,6 @@ struct generation_outputs extern std::string executable_path; extern std::string lora_filename; + +extern void receive_current_token(std::string token); +extern void set_stream_finished(); diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index d7c334c00..b49fd95f7 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -63,6 +63,7 @@ static bool useSmartContext = false; static bool unbanTokens = false; static int blasbatchsize = 512; static bool debugmode = false; +static bool stream_sse = true; static std::string modelname; static std::vector last_n_tokens; static std::vector current_context_tokens; @@ -1040,6 +1041,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o fprintf(stderr, "Failed to predict\n"); snprintf(output.text, sizeof(output.text), "%s", ""); output.status = 0; + set_stream_finished(); return output; } } @@ -1149,7 +1151,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o for (auto id : embd) { - concat_output += FileFormatTokenizeID(id,file_format); + std::string tokenizedstr = FileFormatTokenizeID(id, file_format); + + if (stream_sse) + { + receive_current_token(tokenizedstr); + } + concat_output += tokenizedstr; } if (startedsampling) @@ -1216,6 +1224,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o printf("\nTime Taken - Processing:%.1fs (%.0fms/T), Generation:%.1fs (%.0fms/T), Total:%.1fs", time1, pt1, time2, pt2, (time1 + time2)); fflush(stdout); output.status = 1; + set_stream_finished(); snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str()); return output; diff --git a/koboldcpp.py b/koboldcpp.py index cf8847bc8..553460057 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -5,7 +5,8 @@ import ctypes import os import argparse -import json, http.server, threading, socket, sys, time +import json, http.server, threading, socket, sys, time, 
asyncio +from aiohttp import web stop_token_max = 10 @@ -134,6 +135,7 @@ def init_library(): handle.load_model.restype = ctypes.c_bool handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever handle.generate.restype = generation_outputs + handle.new_token.restype = ctypes.c_char_p def load_model(model_filename): inputs = load_model_inputs() @@ -212,105 +214,145 @@ modelbusy = False defaultport = 5001 KcppVersion = "1.29" -class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): +class ServerRequestHandler: sys_version = "" server_version = "ConcedoLlamaForKoboldServer" + app = web.Application() def __init__(self, addr, port, embedded_kailite): self.addr = addr self.port = port self.embedded_kailite = embedded_kailite - def __call__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + async def send_sse_event(self, response, event, data): + await response.write(f'event: {event}\n'.encode()) + await response.write(f'data: {data}\n\n'.encode()) - def do_GET(self): - global maxctx, maxlen, friendlymodelname, KcppVersion - self.path = self.path.rstrip('/') - response_body = None - if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without / - if args.stream and not "streaming=1" in self.path: - self.path = self.path.replace("streaming=0","") - if self.path.startswith(('/?','?')): - self.path += "&streaming=1" + async def handle_sse_stream(self, request): + response = web.StreamResponse(headers={"Content-Type": "text/event-stream"}) + await response.prepare(request) + + stream_finished = False + + while True: + if handle.has_finished(): + stream_finished = True + if not handle.is_locked(): + token = ctypes.string_at(handle.new_token()).decode('utf-8') + event_data = {"finished": stream_finished, "token": token} + event_str = f"data: {json.dumps(event_data)}" + await self.send_sse_event(response, "message", event_str) + print(token) + print(event_data) + if stream_finished: + break + + async def generate_text(self, newprompt, genparams): + recvtxt = generate( + prompt=newprompt, + max_context_length=genparams.get('max_context_length', maxctx), + max_length=genparams.get('max_length', 50), + temperature=genparams.get('temperature', 0.8), + top_k=genparams.get('top_k', 120), + top_a=genparams.get('top_a', 0.0), + top_p=genparams.get('top_p', 0.85), + typical_p=genparams.get('typical', 1.0), + tfs=genparams.get('tfs', 1.0), + rep_pen=genparams.get('rep_pen', 1.1), + rep_pen_range=genparams.get('rep_pen_range', 128), + seed=genparams.get('sampler_seed', -1), + stop_sequence=genparams.get('stop_sequence', []) + ) + res = {"results": [{"text": recvtxt}]} + + try: + return web.json_response(res) + except: + print("Generate: The response could not be sent, maybe the connection was terminated?") + + async def handle_request(self, request, genparams, newprompt, stream_flag): + if stream_flag: + self.handle_sse_stream(request) + # RUN THESE CONCURRENTLY WITHOUT BLOCKING EACHOTHER + self.generate_text(newprompt, genparams) + + + async def handle_get(self, request): + global maxctx, maxlen, friendlymodelname, KcppVersion, streamLock + path = request.path.rstrip('/') + + if path in ["", "/?"] or path.startswith(('/?','?')): + if args.stream and not "streaming=1" in path: + path = path.replace("streaming=0","") + if path.startswith(('/?','?')): + path += "&streaming=1" else: - self.path = self.path + "?streaming=1" - 
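# Reviewer sketch (not part of the patch): how a ctypes caller can drain the
# polling API added to expose.cpp above. `handle` is the library loaded by
# init_library(), which already sets handle.new_token.restype to c_char_p;
# new_token() returns NULL until receive_current_token() publishes a token,
# so callers poll is_locked()/has_finished() between reads.
import time

def drain_tokens(handle):
    while not handle.has_finished():
        if not handle.is_locked():
            tok = handle.new_token()  # bytes (c_char_p) or None
            if tok is not None:
                yield tok.decode("utf-8")
        else:
            time.sleep(0.005)  # back off instead of busy-spinning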
self.send_response(302) - self.send_header("Location", self.path) - self.end_headers() - print("Force redirect to streaming mode, as --stream is set.") - return None + path = path + "?streaming=1" + raise web.HTTPFound(path) if self.embedded_kailite is None: - response_body = (f"Embedded Kobold Lite is not found.
You will have to connect via the main KoboldAI client, or use this URL to connect.").encode() + return web.Response( + text="Embedded Kobold Lite is not found.
You will have to connect via the main KoboldAI client, or use this URL to connect." + ) else: - response_body = self.embedded_kailite + return web.Response(body=self.embedded_kailite) - elif self.path.endswith(('/api/v1/model', '/api/latest/model')): - response_body = (json.dumps({'result': friendlymodelname }).encode()) + elif path.endswith(('/api/v1/model', '/api/latest/model')): + return web.json_response({'result': friendlymodelname}) - elif self.path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')): - response_body = (json.dumps({"value": maxlen}).encode()) + elif path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')): + return web.json_response({"value": maxlen}) - elif self.path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')): - response_body = (json.dumps({"value": maxctx}).encode()) + elif path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')): + return web.json_response({"value": maxctx}) - elif self.path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')): - response_body = (json.dumps({"value":""}).encode()) + elif path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')): + return web.json_response({"value": ""}) - elif self.path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')): - response_body = (json.dumps({"values": []}).encode()) + elif path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')): + return web.json_response({"values": []}) - elif self.path.endswith(('/api/v1/info/version', '/api/latest/info/version')): - response_body = (json.dumps({"result":"1.2.2"}).encode()) + elif path.endswith(('/api/v1/info/version', '/api/latest/info/version')): + return web.json_response({"result": "1.2.2"}) - elif self.path.endswith(('/api/extra/version')): - response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion}).encode()) + elif path.endswith(('/api/extra/version')): + return web.json_response({"result": "KoboldCpp", "version": KcppVersion}) - if response_body is None: - self.send_response(404) - self.end_headers() - rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.' - self.wfile.write(rp.encode()) - else: - self.send_response(200) - self.send_header('Content-Length', str(len(response_body))) - self.end_headers() - self.wfile.write(response_body) - return + return web.Response(status=404, text="Error: HTTP Server is running, but this endpoint does not exist. 
Please check the URL.") - def do_POST(self): + async def handle_post(self, request): global modelbusy - content_length = int(self.headers['Content-Length']) - body = self.rfile.read(content_length) + body = await request.content.read() basic_api_flag = False kai_api_flag = False - self.path = self.path.rstrip('/') + kai_stream_flag = True + path = request.path.rstrip('/') + print(request) if modelbusy: - self.send_response(503) - self.end_headers() - self.wfile.write(json.dumps({"detail": { - "msg": "Server is busy; please try again later.", - "type": "service_unavailable", - }}).encode()) - return + return web.json_response( + {"detail": {"msg": "Server is busy; please try again later.", "type": "service_unavailable"}}, + status=503, + ) - if self.path.endswith('/request'): + if path.endswith('/request'): basic_api_flag = True - if self.path.endswith(('/api/v1/generate', '/api/latest/generate')): + if path.endswith(('/api/v1/generate', '/api/latest/generate')): kai_api_flag = True + if path.endswith('/api/v1/generate/stream'): + kai_api_flag = True + kai_stream_flag = True + if basic_api_flag or kai_api_flag: genparams = None try: genparams = json.loads(body) except ValueError as e: - self.send_response(503) - self.end_headers() - return + return web.Response(status=503) + utfprint("\nInput: " + json.dumps(genparams)) modelbusy = True @@ -320,115 +362,41 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): fullprompt = genparams.get('text', "") newprompt = fullprompt - recvtxt = "" - res = {} - if kai_api_flag: - recvtxt = generate( - prompt=newprompt, - max_context_length=genparams.get('max_context_length', maxctx), - max_length=genparams.get('max_length', 50), - temperature=genparams.get('temperature', 0.8), - top_k=int(genparams.get('top_k', 120)), - top_a=genparams.get('top_a', 0.0), - top_p=genparams.get('top_p', 0.85), - typical_p=genparams.get('typical', 1.0), - tfs=genparams.get('tfs', 1.0), - rep_pen=genparams.get('rep_pen', 1.1), - rep_pen_range=genparams.get('rep_pen_range', 128), - seed=genparams.get('sampler_seed', -1), - stop_sequence=genparams.get('stop_sequence', []) - ) - utfprint("\nOutput: " + recvtxt) - res = {"results": [{"text": recvtxt}]} - else: - recvtxt = generate( - prompt=newprompt, - max_length=genparams.get('max', 50), - temperature=genparams.get('temperature', 0.8), - top_k=int(genparams.get('top_k', 120)), - top_a=genparams.get('top_a', 0.0), - top_p=genparams.get('top_p', 0.85), - typical_p=genparams.get('typical', 1.0), - tfs=genparams.get('tfs', 1.0), - rep_pen=genparams.get('rep_pen', 1.1), - rep_pen_range=genparams.get('rep_pen_range', 128), - seed=genparams.get('sampler_seed', -1), - stop_sequence=genparams.get('stop_sequence', []) - ) - utfprint("\nOutput: " + recvtxt) - res = {"data": {"seqs":[recvtxt]}} + await self.handle_request(request, genparams, newprompt, kai_stream_flag) - try: - self.send_response(200) - self.end_headers() - self.wfile.write(json.dumps(res).encode()) - except: - print("Generate: The response could not be sent, maybe connection was terminated?") modelbusy = False - return - self.send_response(404) - self.end_headers() + return web.Response() - def do_OPTIONS(self): - self.send_response(200) - self.end_headers() + return web.Response(status=404) - def do_HEAD(self): - self.send_response(200) - self.end_headers() + async def handle_options(self): + return web.Response() - def end_headers(self): - self.send_header('Access-Control-Allow-Origin', '*') - self.send_header('Access-Control-Allow-Methods', '*') - 
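# Reviewer sketch (not part of the patch): a minimal client for the SSE
# endpoint this series adds, assuming a server on localhost:5001. Note that
# in this first draft send_sse_event() prefixes a string that already starts
# with "data: ", so frames go out as "data: data: {...}"; patch 2 below drops
# the inner prefix, and patch 4 moves the route to /api/extra/generate/stream.
import asyncio, json
import aiohttp

async def stream_tokens(prompt):
    async with aiohttp.ClientSession() as session:
        async with session.post("http://localhost:5001/api/v1/generate/stream",
                                json={"prompt": prompt, "max_length": 50}) as resp:
            async for raw in resp.content:  # SSE frames arrive line by line
                line = raw.decode("utf-8").strip()
                if line.startswith("data:"):
                    payload = json.loads(line.split("data:")[-1])
                    print(payload.get("token", ""), end="", flush=True)

# asyncio.run(stream_tokens("Hello"))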
self.send_header('Access-Control-Allow-Headers', '*') - if "/api" in self.path: - self.send_header('Content-type', 'application/json') - else: - self.send_header('Content-type', 'text/html') + async def handle_head(self): + return web.Response() - return super(ServerRequestHandler, self).end_headers() + async def start_server(self): + self.app.router.add_route('GET', '/{tail:.*}', self.handle_get) + self.app.router.add_route('POST', '/{tail:.*}', self.handle_post) + self.app.router.add_route('OPTIONS', '/', self.handle_options) + self.app.router.add_route('HEAD', '/', self.handle_head) -def RunServerMultiThreaded(addr, port, embedded_kailite = None): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind((addr, port)) - sock.listen(5) + runner = web.AppRunner(self.app) + await runner.setup() + site = web.TCPSite(runner, self.addr, self.port) + await site.start() - class Thread(threading.Thread): - def __init__(self, i): - threading.Thread.__init__(self) - self.i = i - self.daemon = True - self.start() - - def run(self): - handler = ServerRequestHandler(addr, port, embedded_kailite) - with http.server.HTTPServer((addr, port), handler, False) as self.httpd: - try: - self.httpd.socket = sock - self.httpd.server_bind = self.server_close = lambda self: None - self.httpd.serve_forever() - except (KeyboardInterrupt,SystemExit): - self.httpd.server_close() - sys.exit(0) - finally: - self.httpd.server_close() - sys.exit(0) - def stop(self): - self.httpd.server_close() - - numThreads = 6 - threadArr = [] - for i in range(numThreads): - threadArr.append(Thread(i)) - while 1: + # Keep Alive try: - time.sleep(10) + while True: + await asyncio.sleep(3600) except KeyboardInterrupt: - for i in range(numThreads): - threadArr[i].stop() - sys.exit(0) + await runner.cleanup() + +async def run_server(addr, port, embedded_kailite=None): + handler = ServerRequestHandler(addr, port, embedded_kailite) + await handler.start_server() def show_gui(): @@ -500,15 +468,15 @@ def show_gui(): unbantokens = tk.IntVar() highpriority = tk.IntVar() disablemmap = tk.IntVar() + frm3 = tk.Frame(root) + tk.Checkbutton(frm3, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0) + tk.Checkbutton(frm3, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1) + tk.Checkbutton(frm3, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0) + tk.Checkbutton(frm3, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1) + tk.Checkbutton(frm3, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0) + tk.Checkbutton(frm3, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1) - frameD = tk.Frame(root) - tk.Checkbutton(frameD, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0) - tk.Checkbutton(frameD, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1) - tk.Checkbutton(frameD, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0) - tk.Checkbutton(frameD, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1) - tk.Checkbutton(frameD, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0) - tk.Checkbutton(frameD, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1) - 
frameD.grid(row=5,column=0,pady=4) + frm3.grid(row=5,column=0,pady=4) # Create button, it will change label text tk.Button( root , text = "Launch", font = ("Impact", 18), bg='#54FA9B', command = guilaunch ).grid(row=6,column=0) @@ -688,7 +656,7 @@ def main(args): except: print("--launch was set, but could not launch web browser automatically.") print(f"Please connect to custom endpoint at {epurl}") - RunServerMultiThreaded(args.host, args.port, embedded_kailite) + asyncio.run(run_server(args.host, args.port, embedded_kailite)) if __name__ == '__main__': print("Welcome to KoboldCpp - Version " + KcppVersion) # just update version manually From 9a8da35ec4a3d37f532a199c3244c0314ea28a61 Mon Sep 17 00:00:00 2001 From: SammCheese Date: Thu, 8 Jun 2023 06:18:23 +0200 Subject: [PATCH 2/8] working streaming. TODO: fix lite --- expose.cpp | 11 ++-- expose.h | 2 +- gpttype_adapter.cpp | 4 +- koboldcpp.py | 123 ++++++++++++++++++++++++++------------------ 4 files changed, 84 insertions(+), 56 deletions(-) diff --git a/expose.cpp b/expose.cpp index 283438a4b..8142e85cd 100644 --- a/expose.cpp +++ b/expose.cpp @@ -210,7 +210,6 @@ extern "C" generation_outputs generate(const generation_inputs inputs, generation_outputs &output) { - finished_stream = false; return gpttype_generate(inputs, output); } @@ -230,6 +229,12 @@ extern "C" bool has_finished() { return finished_stream; } + + + // TODO: dont duplicate code + void bind_set_stream_finished(bool status) { + finished_stream = status; + } } void receive_current_token(std::string token) { @@ -237,6 +242,6 @@ void receive_current_token(std::string token) { new_token_available = true; } -void set_stream_finished() { - finished_stream = true; +void set_stream_finished(bool status) { + finished_stream = status; } diff --git a/expose.h b/expose.h index dac7fffe0..fb360b48d 100644 --- a/expose.h +++ b/expose.h @@ -51,4 +51,4 @@ extern std::string executable_path; extern std::string lora_filename; extern void receive_current_token(std::string token); -extern void set_stream_finished(); +extern void set_stream_finished(bool status = true); diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index b49fd95f7..cdc0b5cf4 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -1041,7 +1041,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o fprintf(stderr, "Failed to predict\n"); snprintf(output.text, sizeof(output.text), "%s", ""); output.status = 0; - set_stream_finished(); + set_stream_finished(true); return output; } } @@ -1224,7 +1224,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o printf("\nTime Taken - Processing:%.1fs (%.0fms/T), Generation:%.1fs (%.0fms/T), Total:%.1fs", time1, pt1, time2, pt2, (time1 + time2)); fflush(stdout); output.status = 1; - set_stream_finished(); + set_stream_finished(true); snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str()); return output; diff --git a/koboldcpp.py b/koboldcpp.py index 553460057..772023221 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -5,8 +5,9 @@ import ctypes import os import argparse -import json, http.server, threading, socket, sys, time, asyncio +import json, sys, time, asyncio from aiohttp import web +from concurrent.futures import ThreadPoolExecutor stop_token_max = 10 @@ -185,7 +186,7 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k= else: inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0 inputs.seed = seed - for n in range(0,stop_token_max): + for n in 
range(stop_token_max): if not stop_sequence or n >= len(stop_sequence): inputs.stop_sequence[n] = "".encode("UTF-8") else: @@ -224,79 +225,96 @@ class ServerRequestHandler: self.port = port self.embedded_kailite = embedded_kailite + + async def generate_text(self, newprompt, genparams): + loop = asyncio.get_event_loop() + executor = ThreadPoolExecutor() + + def run_blocking(): + # Reset finished status before generating + handle.bind_set_stream_finished(False) + + return generate(prompt=newprompt, + max_context_length=genparams.get('max_context_length', maxctx), + max_length=genparams.get('max_length', 50), + temperature=genparams.get('temperature', 0.8), + top_k=genparams.get('top_k', 120), + top_a=genparams.get('top_a', 0.0), + top_p=genparams.get('top_p', 0.85), + typical_p=genparams.get('typical', 1.0), + tfs=genparams.get('tfs', 1.0), + rep_pen=genparams.get('rep_pen', 1.1), + rep_pen_range=genparams.get('rep_pen_range', 128), + seed=genparams.get('sampler_seed', -1), + stop_sequence=genparams.get('stop_sequence', []) + ) + + recvtxt = await loop.run_in_executor(executor, run_blocking) + + res = {"results": [{"text": recvtxt}]} + + try: + return res + except: + print("Generate: Error while generating") + + async def send_sse_event(self, response, event, data): await response.write(f'event: {event}\n'.encode()) await response.write(f'data: {data}\n\n'.encode()) - async def handle_sse_stream(self, request): response = web.StreamResponse(headers={"Content-Type": "text/event-stream"}) await response.prepare(request) - stream_finished = False - - while True: - if handle.has_finished(): - stream_finished = True + while not handle.has_finished(): if not handle.is_locked(): token = ctypes.string_at(handle.new_token()).decode('utf-8') - event_data = {"finished": stream_finished, "token": token} - event_str = f"data: {json.dumps(event_data)}" + event_data = {"token": token} + event_str = json.dumps(event_data) await self.send_sse_event(response, "message", event_str) - print(token) - print(event_data) - if stream_finished: - break + print(event_str) - async def generate_text(self, newprompt, genparams): - recvtxt = generate( - prompt=newprompt, - max_context_length=genparams.get('max_context_length', maxctx), - max_length=genparams.get('max_length', 50), - temperature=genparams.get('temperature', 0.8), - top_k=genparams.get('top_k', 120), - top_a=genparams.get('top_a', 0.0), - top_p=genparams.get('top_p', 0.85), - typical_p=genparams.get('typical', 1.0), - tfs=genparams.get('tfs', 1.0), - rep_pen=genparams.get('rep_pen', 1.1), - rep_pen_range=genparams.get('rep_pen_range', 128), - seed=genparams.get('sampler_seed', -1), - stop_sequence=genparams.get('stop_sequence', []) - ) - res = {"results": [{"text": recvtxt}]} + await asyncio.sleep(0) - try: - return web.json_response(res) - except: - print("Generate: The response could not be sent, maybe the connection was terminated?") + await response.write_eof() + await response.force_close() async def handle_request(self, request, genparams, newprompt, stream_flag): + tasks = [] + if stream_flag: - self.handle_sse_stream(request) - # RUN THESE CONCURRENTLY WITHOUT BLOCKING EACHOTHER - self.generate_text(newprompt, genparams) + tasks.append(self.handle_sse_stream(request,)) + + generate_task = asyncio.create_task(self.generate_text(newprompt, genparams)) + tasks.append(generate_task) + #tasks.append(self.generate_text(newprompt, genparams)) + + try: + await asyncio.gather(*tasks) + if not stream_flag: + generate_result = generate_task.result() + return 
generate_result + except Exception as e: + print(e) async def handle_get(self, request): global maxctx, maxlen, friendlymodelname, KcppVersion, streamLock path = request.path.rstrip('/') - if path in ["", "/?"] or path.startswith(('/?','?')): + if path in ["", "/?"] or path.startswith(('/?', '?')): if args.stream and not "streaming=1" in path: - path = path.replace("streaming=0","") - if path.startswith(('/?','?')): + path = path.replace("streaming=0", "") + if path.startswith(('/?', '?')): path += "&streaming=1" else: path = path + "?streaming=1" - raise web.HTTPFound(path) if self.embedded_kailite is None: - return web.Response( - text="Embedded Kobold Lite is not found.
You will have to connect via the main KoboldAI client, or use this URL to connect." - ) + return web.Response(body=f"Embedded Kobold Lite is not found.
You will have to connect via the main KoboldAI client, or use this URL to connect.".encode()) else: - return web.Response(body=self.embedded_kailite) + return web.Response(body=self.embedded_kailite, content_type='text/html') elif path.endswith(('/api/v1/model', '/api/latest/model')): return web.json_response({'result': friendlymodelname}) @@ -326,7 +344,7 @@ class ServerRequestHandler: body = await request.content.read() basic_api_flag = False kai_api_flag = False - kai_stream_flag = True + kai_sse_stream_flag = True path = request.path.rstrip('/') print(request) @@ -344,7 +362,7 @@ class ServerRequestHandler: if path.endswith('/api/v1/generate/stream'): kai_api_flag = True - kai_stream_flag = True + kai_sse_stream_flag = True if basic_api_flag or kai_api_flag: genparams = None @@ -362,10 +380,13 @@ class ServerRequestHandler: fullprompt = genparams.get('text', "") newprompt = fullprompt - await self.handle_request(request, genparams, newprompt, kai_stream_flag) + gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag) + + if not kai_sse_stream_flag: + return web.Response(body=gen) modelbusy = False - return web.Response() + return web.Response(); return web.Response(status=404) @@ -393,6 +414,8 @@ class ServerRequestHandler: await asyncio.sleep(3600) except KeyboardInterrupt: await runner.cleanup() + await site.stop() + await exit(1) async def run_server(addr, port, embedded_kailite=None): handler = ServerRequestHandler(addr, port, embedded_kailite) From b4e9e185d34d476153de8d6389fc65dfffb51fc9 Mon Sep 17 00:00:00 2001 From: SammCheese Date: Thu, 8 Jun 2023 15:21:00 +0200 Subject: [PATCH 3/8] fix legacy streaming --- koboldcpp.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/koboldcpp.py b/koboldcpp.py index 772023221..bfb534c13 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -5,7 +5,7 @@ import ctypes import os import argparse -import json, sys, time, asyncio +import json, sys, time, asyncio, socket from aiohttp import web from concurrent.futures import ThreadPoolExecutor @@ -255,8 +255,8 @@ class ServerRequestHandler: try: return res - except: - print("Generate: Error while generating") + except Exception as e: + print(f"Generate: Error while generating {e}") async def send_sse_event(self, response, event, data): @@ -273,7 +273,6 @@ class ServerRequestHandler: event_data = {"token": token} event_str = json.dumps(event_data) await self.send_sse_event(response, "message", event_str) - print(event_str) await asyncio.sleep(0) @@ -288,7 +287,6 @@ class ServerRequestHandler: generate_task = asyncio.create_task(self.generate_text(newprompt, genparams)) tasks.append(generate_task) - #tasks.append(self.generate_text(newprompt, genparams)) try: await asyncio.gather(*tasks) @@ -344,7 +342,7 @@ class ServerRequestHandler: body = await request.content.read() basic_api_flag = False kai_api_flag = False - kai_sse_stream_flag = True + kai_sse_stream_flag = False path = request.path.rstrip('/') print(request) @@ -382,10 +380,10 @@ class ServerRequestHandler: gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag) - if not kai_sse_stream_flag: - return web.Response(body=gen) - modelbusy = False + + if not kai_sse_stream_flag: + return web.Response(body=json.dumps(gen).encode()) return web.Response(); return web.Response(status=404) @@ -398,6 +396,11 @@ class ServerRequestHandler: async def start_server(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + 
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        sock.bind((self.addr, self.port))
+        sock.listen(5)
+
         self.app.router.add_route('GET', '/{tail:.*}', self.handle_get)
         self.app.router.add_route('POST', '/{tail:.*}', self.handle_post)
         self.app.router.add_route('OPTIONS', '/', self.handle_options)
@@ -405,7 +408,7 @@

         runner = web.AppRunner(self.app)
         await runner.setup()
-        site = web.TCPSite(runner, self.addr, self.port)
+        site = web.SockSite(runner, sock)
         await site.start()

         # Keep Alive
@@ -415,7 +418,11 @@
         except KeyboardInterrupt:
             await runner.cleanup()
             await site.stop()
-            await exit(1)
+            await sys.exit(0)
+        finally:
+            await runner.cleanup()
+            await site.stop()
+            await sys.exit(0)

 async def run_server(addr, port, embedded_kailite=None):
     handler = ServerRequestHandler(addr, port, embedded_kailite)

From dee692a63e0801c24f371f49bda83d4f0c1e95a1 Mon Sep 17 00:00:00 2001
From: SammCheese
Date: Thu, 8 Jun 2023 15:56:25 +0200
Subject: [PATCH 4/8] compatibility with basic_api, change api path to /extra

---
 koboldcpp.py | 36 ++++++++++++++++++++++++++----------
 1 file changed, 26 insertions(+), 10 deletions(-)

diff --git a/koboldcpp.py b/koboldcpp.py
index bfb534c13..7320b2ce0 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -226,7 +226,7 @@ class ServerRequestHandler:
         self.embedded_kailite = embedded_kailite


-    async def generate_text(self, newprompt, genparams):
+    async def generate_text(self, newprompt, genparams, basic_api_flag):
         loop = asyncio.get_event_loop()
         executor = ThreadPoolExecutor()

@@ -234,6 +234,22 @@ def run_blocking():
             # Reset finished status before generating
             handle.bind_set_stream_finished(False)

+            if basic_api_flag:
+                return generate(
+                    prompt=newprompt,
+                    max_length=genparams.get('max', 50),
+                    temperature=genparams.get('temperature', 0.8),
+                    top_k=int(genparams.get('top_k', 120)),
+                    top_a=genparams.get('top_a', 0.0),
+                    top_p=genparams.get('top_p', 0.85),
+                    typical_p=genparams.get('typical', 1.0),
+                    tfs=genparams.get('tfs', 1.0),
+                    rep_pen=genparams.get('rep_pen', 1.1),
+                    rep_pen_range=genparams.get('rep_pen_range', 128),
+                    seed=genparams.get('sampler_seed', -1),
+                    stop_sequence=genparams.get('stop_sequence', [])
+                )
+
             return generate(prompt=newprompt,
                 max_context_length=genparams.get('max_context_length', maxctx),
                 max_length=genparams.get('max_length', 50),
                 temperature=genparams.get('temperature', 0.8),
                 top_k=genparams.get('top_k', 120),
                 top_a=genparams.get('top_a', 0.0),
                 top_p=genparams.get('top_p', 0.85),
                 typical_p=genparams.get('typical', 1.0),
                 tfs=genparams.get('tfs', 1.0),
                 rep_pen=genparams.get('rep_pen', 1.1),
                 rep_pen_range=genparams.get('rep_pen_range', 128),
                 seed=genparams.get('sampler_seed', -1),
                 stop_sequence=genparams.get('stop_sequence', [])
             )
@@ -251,7 +267,9 @@

         recvtxt = await loop.run_in_executor(executor, run_blocking)

-        res = {"results": [{"text": recvtxt}]}
+        utfprint("\nOutput: " + recvtxt)
+
+        res = {"data": {"seqs":[recvtxt]}} if basic_api_flag else {"results": [{"text": recvtxt}]}

         try:
             return res
@@ -279,20 +297,19 @@
         await response.write_eof()
         await response.force_close()

-    async def handle_request(self, request, genparams, newprompt, stream_flag):
+    async def handle_request(self, request, genparams, newprompt, basic_api_flag, stream_flag):
         tasks = []

         if stream_flag:
             tasks.append(self.handle_sse_stream(request,))

-        generate_task = asyncio.create_task(self.generate_text(newprompt, genparams))
+        generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
         tasks.append(generate_task)

         try:
             await asyncio.gather(*tasks)
-            if not stream_flag:
-                generate_result = generate_task.result()
-                return generate_result
+            generate_result = generate_task.result()
+            return generate_result
         except Exception as e:
             print(e)

@@ -344,7 +361,6 @@
         body = await request.content.read()
         basic_api_flag = False
         kai_api_flag = False
         kai_sse_stream_flag = False
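# Reviewer sketch (not part of the patch): the concurrency shape used by
# handle_request() above. The blocking ctypes generate() call runs in a
# worker thread so the event loop stays free to pump SSE frames, and
# asyncio.gather() joins the two sides. All names here are illustrative.
import asyncio
from concurrent.futures import ThreadPoolExecutor

async def run_generation(run_blocking, pump_sse, streaming):
    loop = asyncio.get_event_loop()
    gen = loop.run_in_executor(ThreadPoolExecutor(max_workers=1), run_blocking)
    if streaming:
        await asyncio.gather(gen, pump_sse())  # stream tokens while generating
    else:
        await gen  # plain non-streaming request
    return gen.result()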
path = request.path.rstrip('/') - print(request) if modelbusy: return web.json_response( @@ -358,7 +374,7 @@ class ServerRequestHandler: if path.endswith(('/api/v1/generate', '/api/latest/generate')): kai_api_flag = True - if path.endswith('/api/v1/generate/stream'): + if path.endswith('/api/extra/generate/stream'): kai_api_flag = True kai_sse_stream_flag = True @@ -378,7 +394,7 @@ class ServerRequestHandler: fullprompt = genparams.get('text', "") newprompt = fullprompt - gen = await self.handle_request(request, genparams, newprompt, kai_sse_stream_flag) + gen = await self.handle_request(request, genparams, newprompt, basic_api_flag, kai_sse_stream_flag) modelbusy = False From 4f665cd63dfd5046cf792d8d220dc8431c1ac650 Mon Sep 17 00:00:00 2001 From: SammCheese Date: Fri, 9 Jun 2023 10:55:07 +0200 Subject: [PATCH 5/8] Squashed commit of the following: commit b617f2847b5914736ccf65bec22caaf49b39c0a8 Merge: 73cc5b8 92f44ff Author: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Fri Jun 9 16:10:35 2023 +0800 Merge branch 'master' into concedo_experimental commit 73cc5b88fbed75d540346bfad11cc5c1e0678705 Author: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Fri Jun 9 16:09:23 2023 +0800 added warning message for unsupported K quants commit 92f44ff7f778ef1b94028b2ba6d39943b5ca0ada Author: AT Date: Fri Jun 9 04:00:51 2023 -0400 metal : add GELU implementation (#1770) Co-authored-by: Adam Treat commit 245fc3c37da5ac5963f9f11a9f4f2ac08d96afc6 Author: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Fri Jun 9 10:39:59 2023 +0300 metal : faster q4_0 (#1775) * metal : 8% faster q4_0 Avoid copying into local uchar4 anf float4. * metal : 17% faster Q4_0 Use 64 threads in a thread group. --------- Co-authored-by: Iwan Kawrakow commit 01dc509038d5288c9139c60005aba63c0565b379 Merge: 0833845 72ff528 Author: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Fri Jun 9 14:53:35 2023 +0800 Merge branch 'master' into concedo_experimental commit 0833845268339719a490269faefe66ac1d2d1dd5 Author: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Fri Jun 9 14:38:31 2023 +0800 merged metal patch directly into the file commit 72ff5282bf0388c60821f504c4c8cc2b1f491aa6 Author: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Thu Jun 8 22:28:21 2023 +0300 metal : add Q2_K implementation (#1762) * metal : add Q2_K implementation 27.1 ms / token on M2 Max 30-core GPU, so about the same speed as Q4_0. Memory throughput is ~156 GB/s. The access pattern used in the Q2_K CUDA implementation resulted in significantly lower performance (~31 ms/token). * Fixing merge conflicts --------- Co-authored-by: Iwan Kawrakow commit 0bf7cf1b296fc9fca05411b37afdf08a531487d2 Author: Georgi Gerganov Date: Thu Jun 8 20:48:14 2023 +0300 Revert "ggml : load data into int8x16x4_t using vld4q_s8 on arm64 (#1738)" This reverts commit 8432d4d9f716b25133e3ed671d91e21f6f3be867. 
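Reviewer note, not part of the series: the GELU commit quoted above uses the standard tanh approximation, and the Metal kernel itself appears verbatim later in this patch. A Python reference for sanity-checking the constants:

import math

def gelu_tanh(x):
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    return 0.5 * x * (1.0 + math.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))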
commit 8432d4d9f716b25133e3ed671d91e21f6f3be867 Author: le.chang Date: Fri Jun 9 00:47:56 2023 +0800 ggml : load data into int8x16x4_t using vld4q_s8 on arm64 (#1738) commit 6fa1613f15c7b92fa1279426dc15eae541d0e7be Author: Hyun-joo KIM Date: Fri Jun 9 01:47:36 2023 +0900 Metal inference enhancement - put hard-wired relative path of ggml-model.model file using a patch file due to lack of NSBundle environment commit 0f291e1f65c1d68201e71ce99c89562a36686b6d Author: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Thu Jun 8 19:46:22 2023 +0300 metal : Q6_K implementation (#1752) * Metal implementation for Q4_K Very slow for now: 42 ms / token, Q4_0 runs in 28 ms/token on my 30-core M2 Max GPU. * Optimizing Q4_K on metal The first token always takes longer, I guess because the metal kernel is being jit-compiled. So, using n = 128 to measure time. At this point Q4_K takes 29.5 ms / token compared to 27.2 ms / token for Q4_0. Quite a bit better than the initial attempt, but still not good enough. * Optimizing q4_K metal dot some more For n = 256 it is now 28.1 ms/token compared to 27 ms/token for q4_0. * Fix after merge with master * Metal implementation for Q6_K Similar to the CUDA implementation. No idea if this is the optimum for Metal, but the few alternative variants I tried all had a lower performance. We get 36.5 ms / token on M2 Max with 30 GPU cores. This corresponds to ~200 GB/second throughput. * clang-tidy : add config back * Much better Q6_K implementation for metal 28.3 ms / token for 7B. Subtracting ~9 ms that is spent in other compute graph operations, we are left with ~19 ms for the matrix multiplications. The model is ~5.5 GB, so we are getting 1000 / 19 * 5.5 = 290 GB/s! --------- Co-authored-by: Iwan Kawrakow commit 7f181600c77efb48a1b2a2e30ff0cd50c294ebea Author: Hyun-joo KIM Date: Fri Jun 9 01:24:22 2023 +0900 Metal inference enhancement - put hard-wired relative path of ggml-model.model file due to lack of NSBundle environment commit 8fc8179919a11738910db07a800f2b176f8adf09 Author: qingfengfenga <41416092+qingfengfenga@users.noreply.github.com> Date: Thu Jun 8 15:58:53 2023 +0800 Add llama.cpp docker support for non-latin languages (#1673) * Modify Dockerfile default character set to improve compatibility (#1673) commit b50b570ed9d699d3d126d72fc02de92926bcd937 Author: Steven Roussey Date: Thu Jun 8 00:12:28 2023 -0700 ggml : fix fprintf warnings (#1720) commit 53aba3f393f2e02a78ddaba2e934893a8bbf3246 Author: Georgi Gerganov Date: Thu Jun 8 10:09:08 2023 +0300 clang-tidy : restore dot file from accidental deletion commit 4161bdc04debb70bf5f275492b4d89fd9330087c Author: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Thu Jun 8 10:08:23 2023 +0300 metal : add Q4_K implementation (#1733) * Metal implementation for Q4_K Very slow for now: 42 ms / token, Q4_0 runs in 28 ms/token on my 30-core M2 Max GPU. * Optimizing Q4_K on metal The first token always takes longer, I guess because the metal kernel is being jit-compiled. So, using n = 128 to measure time. At this point Q4_K takes 29.5 ms / token compared to 27.2 ms / token for Q4_0. Quite a bit better than the initial attempt, but still not good enough. * Optimizing q4_K metal dot some more For n = 256 it is now 28.1 ms/token compared to 27 ms/token for q4_0. 
* Fix after merge with master --------- Co-authored-by: Iwan Kawrakow commit 0035858273ebe0694926bf4414d279f3e1cd109d Author: johnson442 <56517414+johnson442@users.noreply.github.com> Date: Thu Jun 8 08:02:48 2023 +0100 k-quants : add missing compile definition to CMakeLists (#1748) --- ggml-metal.m | 77 ++++++- ggml-metal.metal | 549 +++++++++++++++++++++++++++++++++++++++++++++-- ggml.c | 22 +- llama.cpp | 8 + 4 files changed, 620 insertions(+), 36 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 0953af6a4..89b17ce5e 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -45,13 +45,20 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(scale); GGML_METAL_DECL_KERNEL(silu); GGML_METAL_DECL_KERNEL(relu); + GGML_METAL_DECL_KERNEL(gelu); GGML_METAL_DECL_KERNEL(soft_max); GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); + GGML_METAL_DECL_KERNEL(get_rows_q2_k); + GGML_METAL_DECL_KERNEL(get_rows_q4_k); + GGML_METAL_DECL_KERNEL(get_rows_q6_k); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); + GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); @@ -99,7 +106,7 @@ struct ggml_metal_context * ggml_metal_init(void) { NSError * error = nil; //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; - NSString * path = [[NSBundle mainBundle] pathForResource:@"ggml-metal" ofType:@"metal"]; + NSString * path = @"./ggml-metal.metal"; fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]); NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; @@ -129,13 +136,20 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(scale); GGML_METAL_ADD_KERNEL(silu); GGML_METAL_ADD_KERNEL(relu); + GGML_METAL_ADD_KERNEL(gelu); GGML_METAL_ADD_KERNEL(soft_max); GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); + GGML_METAL_ADD_KERNEL(get_rows_q2_k); + GGML_METAL_ADD_KERNEL(get_rows_q4_k); + GGML_METAL_ADD_KERNEL(get_rows_q6_k); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); + GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); @@ -408,6 +422,20 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_GELU: + { + if (encoder == nil) { + encoder = [command_buffer computeCommandEncoder]; + } + + [encoder setComputePipelineState:ctx->pipeline_gelu]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_SOFT_MAX: { if (encoder == nil) { @@ -514,10 +542,41 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne12 == 1); nth0 = 8; - nth1 = 4; + nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; } break; - default: GGML_ASSERT(false && "not 
implemented"); + case GGML_TYPE_Q2_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32]; + } break; + case GGML_TYPE_Q4_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32]; + } break; + case GGML_TYPE_Q6_K: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 4; + nth1 = 16; + [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32]; + } break; + default: + { + fprintf(stderr, "Asserting on type %d\n",(int)src0t); + GGML_ASSERT(false && "not implemented"); + } }; @@ -540,6 +599,15 @@ void ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_Q2_K) { + [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_Q4_K) { + [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else if (src0t == GGML_TYPE_Q6_K) { + [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -555,6 +623,9 @@ void ggml_metal_graph_compute( switch (src0->type) { case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break; + case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index a359bebe2..745fe8ad3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -81,6 +81,17 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } +constant float GELU_COEF_A = 0.044715f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +kernel void kernel_gelu( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + float x = src0[tpig]; + dst[tpig] = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + kernel void kernel_soft_max( device const float * src0, device float * dst, @@ -267,6 +278,8 @@ kernel void kernel_mul_mat_q4_0_f32( uint2 tptg[[threads_per_threadgroup]]) { const int nb = ne00/QK4_0; + const int8_t m8 = 8; + const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -276,45 +289,65 @@ kernel void kernel_mul_mat_q4_0_f32( const uint nth = tptg.x*tptg.y; const uint ith = tptg.y*tpitg.x + tpitg.y; - sum[ith] = 0.0f; + const int ix = tpitg.y/4; // 0 or 1 + const int iy = tpitg.y - 4*ix; // 0...3 - for (int i = tpitg.x; i < nb; i += tptg.x) { - device const uchar4 * x0p = (device const 
uchar4 *) (x + i)->qs; - device const float4 * y0p = (device const float4 *) (y + i*QK4_0); + const int first = 4 * iy; - const float d = (float)((x + i)->d); + float sumf = 0; - const uchar4 x0v = *(x0p + tpitg.y); - const float4 y0v = *(y0p + tpitg.y + 0); - const float4 y1v = *(y0p + tpitg.y + 4); + for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) { - float acc = 0.0f; + const float d = (float)x[i].d; + + device const uint8_t * xl = x[i].qs + first; + device const float * yl = y + i * QK4_0 + first; + + float2 acc = {0.0f, 0.0f}; for (int j = 0; j < 4; ++j) { - const int x0 = x0v[j] & 0x0F; - const int x1 = x0v[j] >> 4; - const float y0 = y0v[j]; - const float y1 = y1v[j]; + acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8); + acc[1] += yl[j+16] * ((int8_t)(xl[j] >> 4) - m8); - acc += (x0 - 8)*y0 + (x1 - 8)*y1; } - sum[ith] += acc*d; + sumf += d * (acc[0] + acc[1]); } - // accumulate the sum from all threads in the threadgroup + sum[ith] = sumf; + + // + // Accumulate the sum from all threads in the threadgroup + // This version is slightly faster than the commented out one below, + // which I copy-pasted from ggerganov's q4_0 dot product for metal. + // threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = nth/2; i > 0; i /= 2) { - if (ith < i) { - sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; } - + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[r1*ne0 + r0] = sum[0]; } + + //// accumulate the sum from all threads in the threadgroup + //threadgroup_barrier(mem_flags::mem_threadgroup); + //for (uint i = nth/2; i > 0; i /= 2) { + // if (ith < i) { + // sum[ith] += sum[ith + i]; + // } + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + + //if (ith == 0) { + // dst[r1*ne0 + r0] = sum[0]; + //} } kernel void kernel_mul_mat_f16_f32( @@ -338,6 +371,7 @@ kernel void kernel_mul_mat_f16_f32( uint3 tpig[[thread_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 tptg[[threads_per_threadgroup]]) { + const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; const int64_t im = tgpig.z; @@ -503,3 +537,474 @@ kernel void kernel_cpy_f32_f32( dst_data[i00] = src[0]; } } + +//============================================ k-quants ====================================================== + +#define QK_K 256 + +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins +} block_q2_k; + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_k; + +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +} block_q6_k; + +static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { + uchar4 r; + if (j < 4) { + r[0] = q[j+0] & 63; r[1] = q[j+4] & 63; + r[2] = q[j+1] & 63; r[3] = q[j+5] & 63; + } else { + r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + 
r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4); + r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4); + } + return r; +} + +//========================================== dequantization ============================= + +static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = x[i].d; + const float min = x[i].dmin; + + device const uint8_t * q = x[i].qs; + + int is = 0; + float dl, ml; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + uint8_t sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml; + + sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml; + + shift += 2; + } + q += 32; + } + + } +} + +static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = x[i].d; + const float min = x[i].dmin; + + device const uint8_t * q = x[i].qs; + device const uint8_t * scales = x[i].scales; + + int is = 0; + for (int j = 0; j < QK_K; j += 64) { + const uchar4 sc = get_scale_min_k4(is, scales); + const float d1 = d * sc[0]; const float m1 = min * sc[1]; + const float d2 = d * sc[2]; const float m2 = min * sc[3]; + for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; + q += 32; is += 2; + } + + } +} + +static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + device const uint8_t * ql = x[i].ql; + device const uint8_t * qh = x[i].qh; + device const int8_t * sc = x[i].scales; + + const float d = x[i].d; + + for (int n = 0; n < QK_K; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l + 0] = d * sc[is + 0] * q1; + y[l + 32] = d * sc[is + 2] * q2; + y[l + 64] = d * sc[is + 4] * q3; + y[l + 96] = d * sc[is + 6] * q4; + } + y += 128; + ql += 64; + qh += 32; + sc += 8; + } + } +} + +kernel void kernel_get_rows_q2_k( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tpig[[thread_position_in_grid]]) { + const int i = tpig; + const int r = ((device int32_t *) src1)[i]; + + dequantize_row_q2_k( + (device const block_q2_k *) ((device char *) src0 + r*nb01), + (device float *) ((device char *) dst + i*nb1), ne00); +} + +kernel void kernel_get_rows_q4_k( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tpig[[thread_position_in_grid]]) { + const int i = tpig; + const int r = ((device int32_t *) src1)[i]; + + dequantize_row_q4_k( + (device const block_q4_k *) ((device char *) src0 + r*nb01), + (device float *) ((device char *) dst 
+ i*nb1), ne00); +} + +kernel void kernel_get_rows_q6_k( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint tpig[[thread_position_in_grid]]) { + const int i = tpig; + const int r = ((device int32_t *) src1)[i]; + + dequantize_row_q6_k( + (device const block_q6_k *) ((device char *) src0 + r*nb01), + (device float *) ((device char *) dst + i*nb1), ne00); +} + +//====================================== dot products ========================= + +kernel void kernel_mul_mat_q2_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + threadgroup float * sum [[threadgroup(0)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint2 tpig[[thread_position_in_grid]], // we don't use this for now + uint2 tpitg[[thread_position_in_threadgroup]], + uint2 tptg[[threads_per_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb; + device const float * yy = (device const float *) src1 + r1*ne10; + + const int nth = tptg.x*tptg.y; + const int ith = tptg.y*tpitg.x + tpitg.y; + + + const int tid = tpitg.y; // 0...16 + const int il = tid/4; // 0...3 + const int ir = tid%4; // 0...3 + const int ip = il/2; // 0 or 1 + const int shift1 = 4*(il%2);// 0 or 4 + const int shift2 = shift1+2;// 2 or 6 + const int n = 8; + const int is = 4*il + (n*ir)/16; + + sum[ith] = 0.0f; + + float sumf = 0; + for (int i = tpitg.x; i < nb; i += tptg.x) { + + device const uint8_t * q = x[i].qs + 32*ip + n*ir; + device const uint8_t * scales = x[i].scales + is; + + uint8_t d1 = scales[0] & 0xF; + uint8_t m1 = scales[0] >> 4; + uint8_t d2 = scales[2] & 0xF; + uint8_t m2 = scales[2] >> 4; + + device const float * y = yy + i*QK_K + 64*il + n*ir; + + const float dall = (float)x[i].d; + const float dmin = (float)x[i].dmin; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0]; + s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32]; + } + sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2); + + + } + sum[ith] = sumf; + + // + // Accumulate the sum from all threads in the threadgroup + // This version is slightly faster than the commented out one below, + // which I copy-pasted from ggerganov's q4_0 dot product for metal. 
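+    // Reviewer note (not in the original patch): this reduction is a
+    // three-level tree rather than a log2 halving loop. Every 4th thread
+    // folds in its 3 neighbours, every 16th thread then folds the four
+    // 4-wide partials, and thread 0 finally sums one partial per 16
+    // threads, trading a few idle lanes for fewer barrier round-trips.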
+ // + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + dst[r1*ne0 + r0] = sum[0]; + } + + //// accumulate the sum from all threads in the threadgroup + //threadgroup_barrier(mem_flags::mem_threadgroup); + //for (uint i = nth/2; i > 0; i /= 2) { + // if (ith < i) { + // sum[ith] += sum[ith + i]; + // } + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + + //if (ith == 0) { + // dst[r1*ne0 + r0] = sum[0]; + //} +} + +kernel void kernel_mul_mat_q4_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + threadgroup float * sum [[threadgroup(0)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint2 tpig[[thread_position_in_grid]], // we don't use this for now + uint2 tpitg[[thread_position_in_threadgroup]], + uint2 tptg[[threads_per_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb; + device const float * yy = (device const float *) src1 + r1*ne10; + + const uint nth = tptg.x*tptg.y; + const uint ith = tptg.y*tpitg.x + tpitg.y; + + const int tid = tpitg.y; // 0...16 + const int il = tid/4; // 0...3 + const int ir = tid%4; // 0...3 + const int n = 8; + const int is = 2*il; + + sum[ith] = 0.0f; + + float sumf = 0; + for (int i = tpitg.x; i < nb; i += tptg.x) { + + device const uint8_t * q = (x + i)->qs + 32*il + n*ir; + device const float * y = yy + i*QK_K + 64*il + n*ir; + device const uint8_t * scales = (x + i)->scales; + + const float dall = (float)((x + i)->d); + const float dmin = (float)((x + i)->dmin); + + const uchar4 sc = get_scale_min_k4(is, scales); + + float4 s = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0]; + s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32]; + } + sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]); + + } + sum[ith] = sumf; + + // + // Accumulate the sum from all threads in the threadgroup + // This version is slightly faster than the commented out one below, + // which I copy-pasted from ggerganov's q4_0 dot product for metal. 
+ // + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + dst[r1*ne0 + r0] = sum[0]; + } + + //// accumulate the sum from all threads in the threadgroup + //threadgroup_barrier(mem_flags::mem_threadgroup); + //for (uint i = nth/2; i > 0; i /= 2) { + // if (ith < i) { + // sum[ith] += sum[ith + i]; + // } + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + + //if (ith == 0) { + // dst[r1*ne0 + r0] = sum[0]; + //} +} + +kernel void kernel_mul_mat_q6_k_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + threadgroup float * sum [[threadgroup(0)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint2 tpig[[thread_position_in_grid]], // we don't use this for now + uint2 tpitg[[thread_position_in_threadgroup]], + uint2 tptg[[threads_per_threadgroup]]) { + + const uint8_t kmask1 = 0x03; + const uint8_t kmask2 = 0x0C; + const uint8_t kmask3 = 0x30; + const uint8_t kmask4 = 0xC0; + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb; + device const float * yy = (device const float *) src1 + r1*ne10; + + const uint nth = tptg.x*tptg.y; + const uint ith = tptg.y*tpitg.x + tpitg.y; + + const int step = QK_K / tptg.y; // we expect this to be 16 + const int iqs = step * tpitg.y; // 0...240 in steps of 16 + const int ip = iqs / 128; // 0 or 1 + const int il = (iqs - 128*ip)/16; // 0...7 + const int n = 4; + const int is = 8*ip + (n*il)/16; + + float sumf = 0; + for (int i = tpitg.x; i < nb; i += tptg.x) { + + device const uint8_t * ql = x[i].ql + 64*ip + n*il; + device const uint8_t * qh = x[i].qh + 32*ip + n*il; + device const int8_t * sc = x[i].scales + is; + + device const float * y = yy + i * QK_K + 128*ip + n*il; + + const float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } + + sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + + } + + sum[ith] = sumf; + + // + // Accumulate the sum from all threads in the threadgroup + // + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + dst[r1*ne0 + r0] = sum[0]; + } + +} diff --git a/ggml.c b/ggml.c index 
9d4d3583a..3b72b80f3 100644
--- a/ggml.c
+++ b/ggml.c
@@ -14729,12 +14729,12 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;
 
-    fprintf(fout, "%-6s %-12s %8d %8d %d %d %d %16zu %16zu %16zu %16zu %16p %32s\n",
+    fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
             ggml_type_name(tensor->type),
             ggml_op_name (tensor->op),
             tensor->n_dims,
-            (int) ne[0], (int) ne[1], (int) ne[2], (int) ne[3],
-            nb[0], nb[1], nb[2], nb[3],
+            ne[0], ne[1], ne[2], ne[3],
+            nb[0], nb[1], nb[2], nb[3],
             tensor->data,
             tensor->name);
 }
@@ -14743,13 +14743,13 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;
 
-    fprintf(fout, "%-6s %-6s %-12s %8d %d %d %d %d %16zu %16zu %16zu %16zu %8d %16p %32s\n",
+    fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
             arg,
             ggml_type_name(tensor->type),
             ggml_op_name (tensor->op),
             tensor->n_dims,
-            (int) ne[0], (int) ne[1], (int) ne[2], (int) ne[3],
-            nb[0], nb[1], nb[2], nb[3],
+            ne[0], ne[1], ne[2], ne[3],
+            nb[0], nb[1], nb[2], nb[3],
             tensor->n_tasks,
             tensor->data,
             tensor->name);
@@ -14772,11 +14772,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
         FILE * fout = stdout;
 
         fprintf(fout, "\n");
-        fprintf(fout, "%-16s %8x\n", "magic",   GGML_FILE_MAGIC);
-        fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
-        fprintf(fout, "%-16s %8d\n", "leafs",   cgraph->n_leafs);
-        fprintf(fout, "%-16s %8d\n", "nodes",   cgraph->n_nodes);
-        fprintf(fout, "%-16s %8d\n", "eval",    (int) size_eval);
+        fprintf(fout, "%-16s %8x\n", "magic",   GGML_FILE_MAGIC);
+        fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
+        fprintf(fout, "%-16s %8d\n", "leafs",   cgraph->n_leafs);
+        fprintf(fout, "%-16s %8d\n", "nodes",   cgraph->n_nodes);
+        fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval);
 
         // header
         fprintf(fout, "\n");
diff --git a/llama.cpp b/llama.cpp
index d80706446..1e2c9d767 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1028,6 +1028,14 @@ static void llama_model_load_internal(
         }
     }
 
+    #if defined(GGML_USE_CLBLAST)
+    if (file_version == LLAMA_FILE_VERSION_GGJT_V3) {
+        if (hparams.ftype >= LLAMA_FTYPE_MOSTLY_Q2_K && hparams.ftype <= LLAMA_FTYPE_MOSTLY_Q6_K) {
+            printf("\n===\nK-Quants are currently not supported with CLBlast!!!\nPlease select a q4_0, q4_1, q5_0 or q5_1 format instead!\n=====\n");
+        }
+    }
+    #endif
+
     if (vocab_only) {
         return;
     }

From e6231c30553b0720ffdda04106625e3a56b32ae5 Mon Sep 17 00:00:00 2001
From: SammCheese
Date: Fri, 9 Jun 2023 12:17:55 +0200
Subject: [PATCH 6/8] back to http.server, improved implementation

---
 expose.cpp          |  35 ++----
 expose.h            |   4 +-
 gpttype_adapter.cpp |   9 +-
 koboldcpp.py        | 298 ++++++++++++++++++++++++++------------------
 4 files changed, 196 insertions(+), 150 deletions(-)

diff --git a/expose.cpp b/expose.cpp
index 8142e85cd..6e4f656e1 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -24,9 +24,8 @@
 std::string executable_path = "";
 std::string lora_filename = "";
 
-static std::string current_token = "";
-static bool new_token_available = false;
-static bool finished_stream = false;
+bool generation_finished;
+std::vector<std::string> generated_tokens;
 
 extern "C"
 {
@@ -213,35 +212,17 @@ extern "C"
         return gpttype_generate(inputs, output);
     }
 
+    const char* new_token(int idx) {
+        if (generated_tokens.size() <= idx || idx < 0) return nullptr;
-    const char* new_token() {
-        if (new_token_available) {
-            new_token_available = false;
-            return current_token.c_str();
-        }
-        return nullptr;
+        return generated_tokens[idx].c_str();
     }
 
-    bool is_locked() {
-        return !new_token_available;
+    int get_stream_count() {
+        return generated_tokens.size();
     }
 
     bool has_finished() {
-        return finished_stream;
-    }
-
-
-    // TODO: dont duplicate code
-    void bind_set_stream_finished(bool status) {
-        finished_stream = status;
+        return generation_finished;
     }
 }
-
-void receive_current_token(std::string token) {
-    current_token = token;
-    new_token_available = true;
-}
-
-void set_stream_finished(bool status) {
-    finished_stream = status;
-}
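An aside on the exported C API above: the indexed new_token(idx) / get_stream_count() / has_finished() trio lets the host poll an append-only token list instead of racing over a single shared slot as the old new_token()/is_locked() pair did. A minimal consumer-loop sketch (not part of the patches), assuming `handle` is the loaded library with the ctypes signatures set up as in the init_library() hunk below; `drain_tokens` is a hypothetical helper name:

    import ctypes, time

    def drain_tokens(handle):
        # Hypothetical helper: print tokens as the C side appends them.
        idx = 0
        while True:
            if idx < handle.get_stream_count():
                token = handle.new_token(idx)   # bytes via c_char_p, or None if not ready
                if token is not None:
                    print(token.decode('utf-8'), end="", flush=True)
                    idx += 1
            elif handle.has_finished():
                break                           # nothing pending and generation is done
            else:
                time.sleep(0.01)                # generation still running; poll again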
diff --git a/expose.h b/expose.h
index fb360b48d..bb9c5920b 100644
--- a/expose.h
+++ b/expose.h
@@ -50,5 +50,5 @@ struct generation_outputs
 extern std::string executable_path;
 extern std::string lora_filename;
 
-extern void receive_current_token(std::string token);
-extern void set_stream_finished(bool status = true);
+extern std::vector<std::string> generated_tokens;
+extern bool generation_finished;
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index cdc0b5cf4..1c75a8b7c 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -736,6 +736,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     params.n_batch = n_batch;
     params.n_threads = n_threads;
 
+    generation_finished = false; // Set current generation status
+    generated_tokens.clear();    // New generation, new tokens
+
     if (params.repeat_last_n < 1)
     {
         params.repeat_last_n = 1;
@@ -1041,7 +1044,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
             fprintf(stderr, "Failed to predict\n");
             snprintf(output.text, sizeof(output.text), "%s", "");
             output.status = 0;
-            set_stream_finished(true);
+            generation_finished = true;
             return output;
         }
     }
@@ -1155,7 +1158,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 
             if (stream_sse)
             {
-                receive_current_token(tokenizedstr);
+                generated_tokens.push_back(tokenizedstr);
             }
             concat_output += tokenizedstr;
         }
@@ -1224,7 +1227,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     printf("\nTime Taken - Processing:%.1fs (%.0fms/T), Generation:%.1fs (%.0fms/T), Total:%.1fs", time1, pt1, time2, pt2, (time1 + time2));
     fflush(stdout);
     output.status = 1;
-    set_stream_finished(true);
+    generation_finished = true;
 
     snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
 
     return output;
diff --git a/koboldcpp.py b/koboldcpp.py
index 7320b2ce0..e7ea0748e 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -5,8 +5,7 @@
 import ctypes
 import os
 import argparse
-import json, sys, time, asyncio, socket
-from aiohttp import web
+import json, sys, http.server, time, asyncio, socket, threading
 from concurrent.futures import ThreadPoolExecutor
 
 stop_token_max = 10
@@ -137,6 +136,9 @@ def init_library():
     handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. I don't know why it needs to be interpreted that way, but it does
     handle.generate.restype = generation_outputs
     handle.new_token.restype = ctypes.c_char_p
+    handle.new_token.argtypes = [ctypes.c_int]
+    handle.get_stream_count.restype = ctypes.c_int
+    handle.has_finished.restype = ctypes.c_bool
 
 def load_model(model_filename):
     inputs = load_model_inputs()
@@ -215,25 +217,23 @@
 modelbusy = False
 defaultport = 5001
 KcppVersion = "1.29"
 
-class ServerRequestHandler:
+class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
     sys_version = ""
     server_version = "ConcedoLlamaForKoboldServer"
-    app = web.Application()
 
     def __init__(self, addr, port, embedded_kailite):
         self.addr = addr
         self.port = port
         self.embedded_kailite = embedded_kailite
 
+    def __call__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
 
     async def generate_text(self, newprompt, genparams, basic_api_flag):
         loop = asyncio.get_event_loop()
         executor = ThreadPoolExecutor()
 
         def run_blocking():
-            # Reset finished status before generating
-            handle.bind_set_stream_finished(False)
-
             if basic_api_flag:
                 return generate(
                     prompt=newprompt,
                     max_length=genparams.get('max_length', 50),
                     temperature=genparams.get('temperature', 0.8),
                     top_k=genparams.get('top_k', 120),
                     top_a=genparams.get('top_a', 0.0),
                     top_p=genparams.get('top_p', 0.85),
                     typical_p=genparams.get('typical', 1.0),
                     tfs=genparams.get('tfs', 1.0),
                     rep_pen=genparams.get('rep_pen', 1.1),
                     rep_pen_range=genparams.get('rep_pen_range', 128),
                     seed=genparams.get('sampler_seed', -1),
                     stop_sequence=genparams.get('stop_sequence', [])
                 )
-
-        return generate(prompt=newprompt,
-            max_context_length=genparams.get('max_context_length', maxctx),
-            max_length=genparams.get('max_length', 50),
-            temperature=genparams.get('temperature', 0.8),
-            top_k=genparams.get('top_k', 120),
-            top_a=genparams.get('top_a', 0.0),
-            top_p=genparams.get('top_p', 0.85),
-            typical_p=genparams.get('typical', 1.0),
-            tfs=genparams.get('tfs', 1.0),
-            rep_pen=genparams.get('rep_pen', 1.1),
-            rep_pen_range=genparams.get('rep_pen_range', 128),
-            seed=genparams.get('sampler_seed', -1),
-            stop_sequence=genparams.get('stop_sequence', [])
-            )
+            else:
+                return generate(prompt=newprompt,
+                    max_context_length=genparams.get('max_context_length', maxctx),
+                    max_length=genparams.get('max_length', 50),
+                    temperature=genparams.get('temperature', 0.8),
+                    top_k=genparams.get('top_k', 120),
+                    top_a=genparams.get('top_a', 0.0),
+                    top_p=genparams.get('top_p', 0.85),
+                    typical_p=genparams.get('typical', 1.0),
+                    tfs=genparams.get('tfs', 1.0),
+                    rep_pen=genparams.get('rep_pen', 1.1),
+                    rep_pen_range=genparams.get('rep_pen_range', 128),
+                    seed=genparams.get('sampler_seed', -1),
+                    stop_sequence=genparams.get('stop_sequence', [])
+                    )
 
         recvtxt = await loop.run_in_executor(executor, run_blocking)
 
@@ -277,104 +277,139 @@ class ServerRequestHandler:
         try:
             return res
         except Exception as e:
             print(f"Generate: Error while generating {e}")
 
-    async def send_sse_event(self, response, event, data):
-        await response.write(f'event: {event}\n'.encode())
-        await response.write(f'data: {data}\n\n'.encode())
+    async def send_sse_event(self, event, data):
+        self.wfile.write(f'event: {event}\n'.encode())
+        self.wfile.write(f'data: {data}\n\n'.encode())
 
-    async def handle_sse_stream(self, request):
-        response = web.StreamResponse(headers={"Content-Type": "text/event-stream"})
-        await response.prepare(request)
+
+    async def handle_sse_stream(self):
+        self.send_response(200)
+        self.send_header("Content-Type", "text/event-stream")
+        self.send_header("Cache-Control", "no-cache")
+        self.send_header("Connection", "keep-alive")
+        self.end_headers()
+
+        current_token = 0
 
         while not handle.has_finished():
-            if not handle.is_locked():
-                token = ctypes.string_at(handle.new_token()).decode('utf-8')
-                event_data = {"token": token}
+            if current_token < handle.get_stream_count():
+                token = handle.new_token(current_token)
+
+                if token is None: # Token isn't ready yet (received a null pointer)
+                    continue
+
+                current_token += 1
+
+                tokenStr = ctypes.string_at(token).decode('utf-8')
+                event_data = {"token": tokenStr}
                 event_str = json.dumps(event_data)
-                await self.send_sse_event(response, "message", event_str)
+                await self.send_sse_event("message", event_str)
 
             await asyncio.sleep(0)
 
-        await response.write_eof()
-        await response.force_close()
+        await self.wfile.close()
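Each token pushed by send_sse_event() above travels as one standard Server-Sent Events frame. For reference, a single frame on the wire looks like this (the token text is illustrative):

    event: message
    data: {"token": "Hel"}

An SSE frame is terminated by a blank line, which is why send_sse_event() writes '\n\n' after the data line.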
 
-    async def handle_request(self, request, genparams, newprompt, basic_api_flag, stream_flag):
+
+    async def handle_request(self, genparams, newprompt, basic_api_flag, stream_flag):
         tasks = []
 
         if stream_flag:
-            tasks.append(self.handle_sse_stream(request,))
+            tasks.append(self.handle_sse_stream())
 
         generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
         tasks.append(generate_task)
 
         try:
             await asyncio.gather(*tasks)
+            print("done")
             generate_result = generate_task.result()
             return generate_result
         except Exception as e:
             print(e)
 
-    async def handle_get(self, request):
+    def do_GET(self):
         global maxctx, maxlen, friendlymodelname, KcppVersion, streamLock
-        path = request.path.rstrip('/')
+        self.path = self.path.rstrip('/')
+        response_body = None
 
-        if path in ["", "/?"] or path.startswith(('/?', '?')):
-            if args.stream and not "streaming=1" in path:
-                path = path.replace("streaming=0", "")
-                if path.startswith(('/?', '?')):
-                    path += "&streaming=1"
+        if self.path in ["", "/?"] or self.path.startswith(('/?','?')): #it's possible for the root url to have ?params without /
+            if args.stream and not "streaming=1" in self.path:
+                self.path = self.path.replace("streaming=0","")
+                if self.path.startswith(('/?','?')):
+                    self.path += "&streaming=1"
                 else:
-                    path = path + "?streaming=1"
+                    self.path = self.path + "?streaming=1"
+                self.send_response(302)
+                self.send_header("Location", self.path)
+                self.end_headers()
+                print("Force redirect to streaming mode, as --stream is set.")
+                return None
 
             if self.embedded_kailite is None:
-                return web.Response(body=f"Embedded Kobold Lite is not found. You will have to connect via the main KoboldAI client, or use this URL to connect.".encode())
+                response_body = (f"Embedded Kobold Lite is not found. You will have to connect via the main KoboldAI client, or use this URL to connect.").encode()
             else:
-                return web.Response(body=self.embedded_kailite, content_type='text/html')
+                response_body = self.embedded_kailite
 
-        elif path.endswith(('/api/v1/model', '/api/latest/model')):
-            return web.json_response({'result': friendlymodelname})
+        elif self.path.endswith(('/api/v1/model', '/api/latest/model')):
+            response_body = (json.dumps({'result': friendlymodelname }).encode())
 
-        elif path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
-            return web.json_response({"value": maxlen})
+        elif self.path.endswith(('/api/v1/config/max_length', '/api/latest/config/max_length')):
+            response_body = (json.dumps({"value": maxlen}).encode())
 
-        elif path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
-            return web.json_response({"value": maxctx})
+        elif self.path.endswith(('/api/v1/config/max_context_length', '/api/latest/config/max_context_length')):
+            response_body = (json.dumps({"value": maxctx}).encode())
 
-        elif path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
-            return web.json_response({"value": ""})
+        elif self.path.endswith(('/api/v1/config/soft_prompt', '/api/latest/config/soft_prompt')):
+            response_body = (json.dumps({"value":""}).encode())
 
-        elif path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')):
-            return web.json_response({"values": []})
+        elif self.path.endswith(('/api/v1/config/soft_prompts_list', '/api/latest/config/soft_prompts_list')):
+            response_body = (json.dumps({"values": []}).encode())
 
-        elif path.endswith(('/api/v1/info/version', '/api/latest/info/version')):
-            return web.json_response({"result": "1.2.2"})
+        elif self.path.endswith(('/api/v1/info/version', '/api/latest/info/version')):
+            response_body = (json.dumps({"result":"1.2.2"}).encode())
 
-        elif path.endswith(('/api/extra/version')):
-            return web.json_response({"result": "KoboldCpp", "version": KcppVersion})
+        elif self.path.endswith(('/api/extra/version')):
+            response_body = (json.dumps({"result":"KoboldCpp","version":KcppVersion}).encode())
 
-        return web.Response(status=404, text="Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.")
+        if response_body is None:
+            self.send_response(404)
+            self.end_headers()
+            rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
+            self.wfile.write(rp.encode())
+        else:
+            self.send_response(200)
+            self.send_header('Content-Length', str(len(response_body)))
+            self.end_headers()
+            self.wfile.write(response_body)
+        return
 
-    async def handle_post(self, request):
+    def do_POST(self):
         global modelbusy
-        body = await request.content.read()
+        content_length = int(self.headers['Content-Length'])
+        body = self.rfile.read(content_length)
         basic_api_flag = False
         kai_api_flag = False
         kai_sse_stream_flag = False
-        path = request.path.rstrip('/')
+        self.path = self.path.rstrip('/')
+
         if modelbusy:
-            return web.json_response(
-                {"detail": {"msg": "Server is busy; please try again later.", "type": "service_unavailable"}},
-                status=503,
-            )
+            self.send_response(503)
+            self.end_headers()
+            self.wfile.write(json.dumps({"detail": {
+                "msg": "Server is busy; please try again later.",
+                "type": "service_unavailable",
+            }}).encode())
+            return
 
-        if path.endswith('/request'):
+        if self.path.endswith('/request'):
            basic_api_flag = True
 
-        if path.endswith(('/api/v1/generate', '/api/latest/generate')):
+        if self.path.endswith(('/api/v1/generate', '/api/latest/generate')):
             kai_api_flag = True
 
-        if path.endswith('/api/extra/generate/stream'):
+        if self.path.endswith('/api/extra/generate/stream'):
             kai_api_flag = True
             kai_sse_stream_flag = True
 
@@ -383,66 +418,94 @@ class ServerRequestHandler:
         try:
             genparams = json.loads(body)
         except ValueError as e:
-            return web.Response(status=503)
+            return self.send_response(503)
 
         utfprint("\nInput: " + json.dumps(genparams))
 
         modelbusy = True
+
         if kai_api_flag:
             fullprompt = genparams.get('prompt', "")
         else:
             fullprompt = genparams.get('text', "")
         newprompt = fullprompt
 
-        gen = await self.handle_request(request, genparams, newprompt, basic_api_flag, kai_sse_stream_flag)
+        gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
+
+        try:
+            self.wfile.write(json.dumps(gen).encode())
+        except:
+            print("Generate: The response could not be sent, maybe the connection was terminated?")
 
         modelbusy = False
-        if not kai_sse_stream_flag:
-            return web.Response(body=json.dumps(gen).encode())
-        return web.Response();
+        return
 
-        return web.Response(status=404)
+        self.send_response(404)
+        self.end_headers()
 
-    async def handle_options(self):
-        return web.Response()
-    async def handle_head(self):
-        return web.Response()
+    def do_OPTIONS(self):
+        self.send_response(200)
+        self.end_headers()
 
+    def do_HEAD(self):
+        self.send_response(200)
+        self.end_headers()
 
-    async def start_server(self):
+    def end_headers(self):
+        self.send_header('Access-Control-Allow-Origin', '*')
+        self.send_header('Access-Control-Allow-Methods', '*')
+        self.send_header('Access-Control-Allow-Headers', '*')
+        if "/api" in self.path:
+            if self.path.endswith("/stream"):
+                self.send_header('Content-type', 'text/event-stream')
+            self.send_header('Content-type', 'application/json')
+        else:
+            self.send_header('Content-type', 'text/html')
+        return super(ServerRequestHandler, self).end_headers()
 
-        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        sock.bind((self.addr, self.port))
-        sock.listen(5)
 
-        self.app.router.add_route('GET', '/{tail:.*}', self.handle_get)
-        self.app.router.add_route('POST', '/{tail:.*}', self.handle_post)
-        self.app.router.add_route('OPTIONS', '/', self.handle_options)
-        self.app.router.add_route('HEAD', '/', self.handle_head)
 
-        runner = web.AppRunner(self.app)
-        await runner.setup()
-        site = web.SockSite(runner, sock)
-        await site.start()
+def RunServerMultiThreaded(addr, port, embedded_kailite = None):
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    sock.bind((addr, port))
+    sock.listen(5)
 
-        # Keep Alive
+    class Thread(threading.Thread):
+        def __init__(self, i):
+            threading.Thread.__init__(self)
+            self.i = i
+            self.daemon = True
+            self.start()
+
+        def run(self):
+            handler = ServerRequestHandler(addr, port, embedded_kailite)
+            with http.server.HTTPServer((addr, port), handler, False) as self.httpd:
+                try:
+                    self.httpd.socket = sock
+                    self.httpd.server_bind = self.server_close = lambda self: None
+                    self.httpd.serve_forever()
+                except (KeyboardInterrupt,SystemExit):
+                    self.httpd.server_close()
+                    sys.exit(0)
+                finally:
+                    self.httpd.server_close()
+                    sys.exit(0)
+
+        def stop(self):
+            self.httpd.server_close()
+
+    numThreads = 6
+    threadArr = []
+    for i in range(numThreads):
+        threadArr.append(Thread(i))
+    while 1:
         try:
-            while True:
-                await asyncio.sleep(3600)
+            time.sleep(10)
         except KeyboardInterrupt:
-            await runner.cleanup()
-            await site.stop()
-            await sys.exit(0)
-        finally:
-            await runner.cleanup()
-            await site.stop()
-            await sys.exit(0)
-
-async def run_server(addr, port, embedded_kailite=None):
-    handler = ServerRequestHandler(addr, port, embedded_kailite)
-    await handler.start_server()
+            for i in range(numThreads):
+                threadArr[i].stop()
+            sys.exit(0)
 
 
 def show_gui():
@@ -514,15 +577,14 @@ def show_gui():
     unbantokens = tk.IntVar()
     highpriority = tk.IntVar()
     disablemmap = tk.IntVar()
 
-    frm3 = tk.Frame(root)
-    tk.Checkbutton(frm3, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0)
-    tk.Checkbutton(frm3, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1)
-    tk.Checkbutton(frm3, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0)
-    tk.Checkbutton(frm3, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1)
-    tk.Checkbutton(frm3, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0)
-    tk.Checkbutton(frm3, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1)
-
-    frm3.grid(row=5,column=0,pady=4)
+    frameD = tk.Frame(root)
+    tk.Checkbutton(frameD, text='Streaming Mode',variable=stream, onvalue=1, offvalue=0).grid(row=0,column=0)
+    tk.Checkbutton(frameD, text='Use SmartContext',variable=smartcontext, onvalue=1, offvalue=0).grid(row=0,column=1)
+    tk.Checkbutton(frameD, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0)
+    tk.Checkbutton(frameD, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1)
+    tk.Checkbutton(frameD, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0)
+    tk.Checkbutton(frameD, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1)
+    frameD.grid(row=5,column=0,pady=4)
 
     # Create button, it will change label text
     tk.Button( root , text = "Launch", font = ("Impact", 18), bg='#54FA9B', command = guilaunch ).grid(row=6,column=0)
@@ -702,7 +764,7 @@ def main(args):
         except:
             print("--launch was set, but could not launch web browser automatically.")
             print(f"Please connect to custom endpoint at {epurl}")
-    asyncio.run(run_server(args.host, args.port, embedded_kailite))
+    asyncio.run(RunServerMultiThreaded(args.host, args.port, embedded_kailite))
 
 if __name__ == '__main__':
     print("Welcome to KoboldCpp - Version " +
KcppVersion) # just update version manually From c99ab9df33f21234473f5f7653130a5424de36c7 Mon Sep 17 00:00:00 2001 From: SammCheese Date: Fri, 9 Jun 2023 12:19:08 +0200 Subject: [PATCH 7/8] Revert "Squashed commit of the following:" This reverts commit 4f665cd63dfd5046cf792d8d220dc8431c1ac650. --- ggml-metal.m | 77 +------ ggml-metal.metal | 551 ++--------------------------------------------- ggml.c | 22 +- llama.cpp | 8 - 4 files changed, 37 insertions(+), 621 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 89b17ce5e..0953af6a4 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -45,20 +45,13 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(scale); GGML_METAL_DECL_KERNEL(silu); GGML_METAL_DECL_KERNEL(relu); - GGML_METAL_DECL_KERNEL(gelu); GGML_METAL_DECL_KERNEL(soft_max); GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); - GGML_METAL_DECL_KERNEL(get_rows_q2_k); - GGML_METAL_DECL_KERNEL(get_rows_q4_k); - GGML_METAL_DECL_KERNEL(get_rows_q6_k); GGML_METAL_DECL_KERNEL(rms_norm); GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); - GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32); GGML_METAL_DECL_KERNEL(rope); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); @@ -106,7 +99,7 @@ struct ggml_metal_context * ggml_metal_init(void) { NSError * error = nil; //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; - NSString * path = @"./ggml-metal.metal"; + NSString * path = [[NSBundle mainBundle] pathForResource:@"ggml-metal" ofType:@"metal"]; fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]); NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; @@ -136,20 +129,13 @@ struct ggml_metal_context * ggml_metal_init(void) { GGML_METAL_ADD_KERNEL(scale); GGML_METAL_ADD_KERNEL(silu); GGML_METAL_ADD_KERNEL(relu); - GGML_METAL_ADD_KERNEL(gelu); GGML_METAL_ADD_KERNEL(soft_max); GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); - GGML_METAL_ADD_KERNEL(get_rows_q2_k); - GGML_METAL_ADD_KERNEL(get_rows_q4_k); - GGML_METAL_ADD_KERNEL(get_rows_q6_k); GGML_METAL_ADD_KERNEL(rms_norm); GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); - GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32); GGML_METAL_ADD_KERNEL(rope); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); @@ -422,20 +408,6 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; - case GGML_OP_GELU: - { - if (encoder == nil) { - encoder = [command_buffer computeCommandEncoder]; - } - - [encoder setComputePipelineState:ctx->pipeline_gelu]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - - const int64_t n = ggml_nelements(dst); - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; case GGML_OP_SOFT_MAX: { if (encoder == nil) { @@ -542,41 +514,10 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne12 == 1); nth0 = 8; - nth1 = 8; + nth1 = 4; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; } break; - case GGML_TYPE_Q2_K: - { - 
GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 4; - nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32]; - } break; - case GGML_TYPE_Q4_K: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 4; - nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32]; - } break; - case GGML_TYPE_Q6_K: - { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - - nth0 = 4; - nth1 = 16; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32]; - } break; - default: - { - fprintf(stderr, "Asserting on type %d\n",(int)src0t); - GGML_ASSERT(false && "not implemented"); - } + default: GGML_ASSERT(false && "not implemented"); }; @@ -599,15 +540,6 @@ void ggml_metal_graph_compute( if (src0t == GGML_TYPE_Q4_0) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else if (src0t == GGML_TYPE_Q2_K) { - [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else if (src0t == GGML_TYPE_Q4_K) { - [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else if (src0t == GGML_TYPE_Q6_K) { - [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -623,9 +555,6 @@ void ggml_metal_graph_compute( switch (src0->type) { case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; - case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break; - case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; - case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break; default: GGML_ASSERT(false && "not implemented"); } diff --git a/ggml-metal.metal b/ggml-metal.metal index 745fe8ad3..a359bebe2 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -81,17 +81,6 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } -constant float GELU_COEF_A = 0.044715f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - -kernel void kernel_gelu( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - float x = src0[tpig]; - dst[tpig] = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - kernel void kernel_soft_max( device const float * src0, device float * dst, @@ -278,8 +267,6 @@ kernel void kernel_mul_mat_q4_0_f32( uint2 tptg[[threads_per_threadgroup]]) { const int nb = ne00/QK4_0; - const int8_t m8 = 8; - const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -289,65 +276,45 @@ kernel void kernel_mul_mat_q4_0_f32( const uint nth = tptg.x*tptg.y; const uint ith = tptg.y*tpitg.x + tpitg.y; - const int ix = tpitg.y/4; // 0 or 1 - const int iy = tpitg.y - 4*ix; // 0...3 + sum[ith] = 0.0f; - const int first = 4 * iy; + for (int i = tpitg.x; i < nb; i += tptg.x) { + device 
const uchar4 * x0p = (device const uchar4 *) (x + i)->qs; + device const float4 * y0p = (device const float4 *) (y + i*QK4_0); - float sumf = 0; + const float d = (float)((x + i)->d); - for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) { + const uchar4 x0v = *(x0p + tpitg.y); + const float4 y0v = *(y0p + tpitg.y + 0); + const float4 y1v = *(y0p + tpitg.y + 4); - const float d = (float)x[i].d; - - device const uint8_t * xl = x[i].qs + first; - device const float * yl = y + i * QK4_0 + first; - - float2 acc = {0.0f, 0.0f}; + float acc = 0.0f; for (int j = 0; j < 4; ++j) { + const int x0 = x0v[j] & 0x0F; + const int x1 = x0v[j] >> 4; - acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8); - acc[1] += yl[j+16] * ((int8_t)(xl[j] >> 4) - m8); + const float y0 = y0v[j]; + const float y1 = y1v[j]; + acc += (x0 - 8)*y0 + (x1 - 8)*y1; } - sumf += d * (acc[0] + acc[1]); + sum[ith] += acc*d; } - sum[ith] = sumf; + // accumulate the sum from all threads in the threadgroup + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint i = nth/2; i > 0; i /= 2) { + if (ith < i) { + sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } - // - // Accumulate the sum from all threads in the threadgroup - // This version is slightly faster than the commented out one below, - // which I copy-pasted from ggerganov's q4_0 dot product for metal. - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[r1*ne0 + r0] = sum[0]; } - - //// accumulate the sum from all threads in the threadgroup - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (uint i = nth/2; i > 0; i /= 2) { - // if (ith < i) { - // sum[ith] += sum[ith + i]; - // } - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - - //if (ith == 0) { - // dst[r1*ne0 + r0] = sum[0]; - //} } kernel void kernel_mul_mat_f16_f32( @@ -371,7 +338,6 @@ kernel void kernel_mul_mat_f16_f32( uint3 tpig[[thread_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 tptg[[threads_per_threadgroup]]) { - const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; const int64_t im = tgpig.z; @@ -537,474 +503,3 @@ kernel void kernel_cpy_f32_f32( dst_data[i00] = src[0]; } } - -//============================================ k-quants ====================================================== - -#define QK_K 256 - -typedef struct { - uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits - uint8_t qs[QK_K/4]; // quants - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins -} block_q2_k; - -typedef struct { - half d; // super-block scale for quantized scales - half dmin; // super-block scale for quantized mins - uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits - uint8_t qs[QK_K/2]; // 4--bit quants -} block_q4_k; - -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - half d; // super-block scale -} block_q6_k; - -static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { - uchar4 r; - if (j < 4) { - r[0] = q[j+0] & 63; r[1] = q[j+4] & 63; - r[2] = q[j+1] & 63; r[3] = q[j+5] & 63; - 
} else { - r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); - r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); - r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4); - r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4); - } - return r; -} - -//========================================== dequantization ============================= - -static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const float d = x[i].d; - const float min = x[i].dmin; - - device const uint8_t * q = x[i].qs; - - int is = 0; - float dl, ml; - for (int n = 0; n < QK_K; n += 128) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - - uint8_t sc = x[i].scales[is++]; - dl = d * (sc & 0xF); ml = min * (sc >> 4); - for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml; - - sc = x[i].scales[is++]; - dl = d * (sc & 0xF); ml = min * (sc >> 4); - for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml; - - shift += 2; - } - q += 32; - } - - } -} - -static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - const float d = x[i].d; - const float min = x[i].dmin; - - device const uint8_t * q = x[i].qs; - device const uint8_t * scales = x[i].scales; - - int is = 0; - for (int j = 0; j < QK_K; j += 64) { - const uchar4 sc = get_scale_min_k4(is, scales); - const float d1 = d * sc[0]; const float m1 = min * sc[1]; - const float d2 = d * sc[2]; const float m2 = min * sc[3]; - for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; - for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; - q += 32; is += 2; - } - - } -} - -static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) { - assert(k % QK_K == 0); - const int nb = k / QK_K; - - for (int i = 0; i < nb; i++) { - - device const uint8_t * ql = x[i].ql; - device const uint8_t * qh = x[i].qh; - device const int8_t * sc = x[i].scales; - - const float d = x[i].d; - - for (int n = 0; n < QK_K; n += 128) { - for (int l = 0; l < 32; ++l) { - int is = l/16; - const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - y[l + 0] = d * sc[is + 0] * q1; - y[l + 32] = d * sc[is + 2] * q2; - y[l + 64] = d * sc[is + 4] * q3; - y[l + 96] = d * sc[is + 6] * q4; - } - y += 128; - ql += 64; - qh += 32; - sc += 8; - } - } -} - -kernel void kernel_get_rows_q2_k( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q2_k( - (device const block_q2_k *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -kernel void kernel_get_rows_q4_k( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q4_k( - (device const block_q4_k *) ((device char 
*) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -kernel void kernel_get_rows_q6_k( - device const void * src0, - device const int * src1, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant uint64_t & nb1, - uint tpig[[thread_position_in_grid]]) { - const int i = tpig; - const int r = ((device int32_t *) src1)[i]; - - dequantize_row_q6_k( - (device const block_q6_k *) ((device char *) src0 + r*nb01), - (device float *) ((device char *) dst + i*nb1), ne00); -} - -//====================================== dot products ========================= - -kernel void kernel_mul_mat_q2_k_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - threadgroup float * sum [[threadgroup(0)]], - uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], // we don't use this for now - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - - device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; - - - const int tid = tpitg.y; // 0...16 - const int il = tid/4; // 0...3 - const int ir = tid%4; // 0...3 - const int ip = il/2; // 0 or 1 - const int shift1 = 4*(il%2);// 0 or 4 - const int shift2 = shift1+2;// 2 or 6 - const int n = 8; - const int is = 4*il + (n*ir)/16; - - sum[ith] = 0.0f; - - float sumf = 0; - for (int i = tpitg.x; i < nb; i += tptg.x) { - - device const uint8_t * q = x[i].qs + 32*ip + n*ir; - device const uint8_t * scales = x[i].scales + is; - - uint8_t d1 = scales[0] & 0xF; - uint8_t m1 = scales[0] >> 4; - uint8_t d2 = scales[2] & 0xF; - uint8_t m2 = scales[2] >> 4; - - device const float * y = yy + i*QK_K + 64*il + n*ir; - - const float dall = (float)x[i].d; - const float dmin = (float)x[i].dmin; - - float4 s = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0]; - s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32]; - } - sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2); - - - } - sum[ith] = sumf; - - // - // Accumulate the sum from all threads in the threadgroup - // This version is slightly faster than the commented out one below, - // which I copy-pasted from ggerganov's q4_0 dot product for metal. 
- // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; - } - - //// accumulate the sum from all threads in the threadgroup - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (uint i = nth/2; i > 0; i /= 2) { - // if (ith < i) { - // sum[ith] += sum[ith + i]; - // } - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - - //if (ith == 0) { - // dst[r1*ne0 + r0] = sum[0]; - //} -} - -kernel void kernel_mul_mat_q4_k_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - threadgroup float * sum [[threadgroup(0)]], - uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], // we don't use this for now - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - - device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - - const uint nth = tptg.x*tptg.y; - const uint ith = tptg.y*tpitg.x + tpitg.y; - - const int tid = tpitg.y; // 0...16 - const int il = tid/4; // 0...3 - const int ir = tid%4; // 0...3 - const int n = 8; - const int is = 2*il; - - sum[ith] = 0.0f; - - float sumf = 0; - for (int i = tpitg.x; i < nb; i += tptg.x) { - - device const uint8_t * q = (x + i)->qs + 32*il + n*ir; - device const float * y = yy + i*QK_K + 64*il + n*ir; - device const uint8_t * scales = (x + i)->scales; - - const float dall = (float)((x + i)->d); - const float dmin = (float)((x + i)->dmin); - - const uchar4 sc = get_scale_min_k4(is, scales); - - float4 s = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0]; - s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32]; - } - sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]); - - } - sum[ith] = sumf; - - // - // Accumulate the sum from all threads in the threadgroup - // This version is slightly faster than the commented out one below, - // which I copy-pasted from ggerganov's q4_0 dot product for metal. 
- // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; - } - - //// accumulate the sum from all threads in the threadgroup - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (uint i = nth/2; i > 0; i /= 2) { - // if (ith < i) { - // sum[ith] += sum[ith + i]; - // } - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - - //if (ith == 0) { - // dst[r1*ne0 + r0] = sum[0]; - //} -} - -kernel void kernel_mul_mat_q6_k_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - threadgroup float * sum [[threadgroup(0)]], - uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], // we don't use this for now - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { - - const uint8_t kmask1 = 0x03; - const uint8_t kmask2 = 0x0C; - const uint8_t kmask3 = 0x30; - const uint8_t kmask4 = 0xC0; - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - - device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - - const uint nth = tptg.x*tptg.y; - const uint ith = tptg.y*tpitg.x + tpitg.y; - - const int step = QK_K / tptg.y; // we expect this to be 16 - const int iqs = step * tpitg.y; // 0...240 in steps of 16 - const int ip = iqs / 128; // 0 or 1 - const int il = (iqs - 128*ip)/16; // 0...7 - const int n = 4; - const int is = 8*ip + (n*il)/16; - - float sumf = 0; - for (int i = tpitg.x; i < nb; i += tptg.x) { - - device const uint8_t * ql = x[i].ql + 64*ip + n*il; - device const uint8_t * qh = x[i].qh + 32*ip + n*il; - device const int8_t * sc = x[i].scales + is; - - device const float * y = yy + i * QK_K + 128*ip + n*il; - - const float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); - sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); - sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32); - sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); - } - - sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); - - } - - sum[ith] = sumf; - - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; - } - -} diff --git a/ggml.c b/ggml.c index 
3b72b80f3..9d4d3583a 100644
--- a/ggml.c
+++ b/ggml.c
@@ -14729,12 +14729,12 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;
 
-    fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
+    fprintf(fout, "%-6s %-12s %8d %8d %d %d %d %16zu %16zu %16zu %16zu %16p %32s\n",
             ggml_type_name(tensor->type),
             ggml_op_name (tensor->op),
             tensor->n_dims,
-            ne[0], ne[1], ne[2], ne[3],
-            nb[0], nb[1], nb[2], nb[3],
+            (int) ne[0], (int) ne[1], (int) ne[2], (int) ne[3],
+            nb[0], nb[1], nb[2], nb[3],
             tensor->data,
             tensor->name);
 }
@@ -14743,13 +14743,13 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;
 
-    fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
+    fprintf(fout, "%-6s %-6s %-12s %8d %d %d %d %d %16zu %16zu %16zu %16zu %8d %16p %32s\n",
             arg,
             ggml_type_name(tensor->type),
             ggml_op_name (tensor->op),
             tensor->n_dims,
-            ne[0], ne[1], ne[2], ne[3],
-            nb[0], nb[1], nb[2], nb[3],
+            (int) ne[0], (int) ne[1], (int) ne[2], (int) ne[3],
+            nb[0], nb[1], nb[2], nb[3],
             tensor->n_tasks,
             tensor->data,
             tensor->name);
@@ -14772,11 +14772,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
         FILE * fout = stdout;
 
         fprintf(fout, "\n");
-        fprintf(fout, "%-16s %8x\n", "magic",   GGML_FILE_MAGIC);
-        fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
-        fprintf(fout, "%-16s %8d\n", "leafs",   cgraph->n_leafs);
-        fprintf(fout, "%-16s %8d\n", "nodes",   cgraph->n_nodes);
-        fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval);
+        fprintf(fout, "%-16s %8x\n", "magic",   GGML_FILE_MAGIC);
+        fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
+        fprintf(fout, "%-16s %8d\n", "leafs",   cgraph->n_leafs);
+        fprintf(fout, "%-16s %8d\n", "nodes",   cgraph->n_nodes);
+        fprintf(fout, "%-16s %8d\n", "eval",    (int) size_eval);
 
         // header
         fprintf(fout, "\n");
diff --git a/llama.cpp b/llama.cpp
index 1e2c9d767..d80706446 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1028,14 +1028,6 @@ static void llama_model_load_internal(
         }
     }
 
-    #if defined(GGML_USE_CLBLAST)
-    if (file_version == LLAMA_FILE_VERSION_GGJT_V3) {
-        if (hparams.ftype >= LLAMA_FTYPE_MOSTLY_Q2_K && hparams.ftype <= LLAMA_FTYPE_MOSTLY_Q6_K) {
-            printf("\n===\nK-Quants are currently not supported with CLBlast!!!\nPlease select a q4_0, q4_1, q5_0 or q5_1 format instead!\n=====\n");
-        }
-    }
-    #endif
-
     if (vocab_only) {
         return;
     }

From 57b0b53b5457a96afb0c7596d859e54d166cd42f Mon Sep 17 00:00:00 2001
From: SammCheese
Date: Fri, 9 Jun 2023 12:39:35 +0200
Subject: [PATCH 8/8] fix kobold lite generation

---
 koboldcpp.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)
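The key change in this patch is ordering: with http.server, the status line and headers must be flushed before anything is written to wfile, otherwise the raw JSON body lands where the client expects an HTTP status line. A minimal standalone sketch of the pattern (illustrative, not KoboldCpp code):

    import http.server, json

    class Handler(http.server.BaseHTTPRequestHandler):
        def do_POST(self):
            payload = json.dumps({"ok": True}).encode()
            self.send_response(200)     # status line first
            self.end_headers()          # then the headers
            self.wfile.write(payload)   # only then the body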
diff --git a/koboldcpp.py b/koboldcpp.py
index e7ea0748e..981513987 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -274,7 +274,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         try:
             return res
         except Exception as e:
-            print(f"Generate: Error while generating {e}")
+            print(f"Generate: Error while generating: {e}")
 
 
     async def send_sse_event(self, event, data):
@@ -307,7 +307,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
             await asyncio.sleep(0)
 
-        await self.wfile.close()
+        # Implement connection closing here
 
 
     async def handle_request(self, genparams, newprompt, basic_api_flag, stream_flag):
@@ -321,7 +321,6 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
 
         try:
             await asyncio.gather(*tasks)
-            print("done")
             generate_result = generate_task.result()
             return generate_result
         except Exception as e:
@@ -433,6 +432,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
         gen = asyncio.run(self.handle_request(genparams, newprompt, basic_api_flag, kai_sse_stream_flag))
 
         try:
+            self.send_response(200)
+            self.end_headers()
             self.wfile.write(json.dumps(gen).encode())
         except:
             print("Generate: The response could not be sent, maybe the connection was terminated?")
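With the series applied, the stream can be exercised end to end. The sketch below (not part of the patches) is a bare-bones client for the /api/extra/generate/stream endpoint using only the standard library; the host and port assume the default localhost:5001, and the frame format matches send_sse_event() in patch 6/8:

    import json
    import urllib.request

    req = urllib.request.Request(
        "http://localhost:5001/api/extra/generate/stream",
        data=json.dumps({"prompt": "Hello", "max_length": 32}).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        for raw in resp:                   # SSE frames arrive as "event:"/"data:" lines
            line = raw.decode("utf-8").strip()
            if line.startswith("data: "):  # each data line carries one {"token": ...} payload
                print(json.loads(line[6:])["token"], end="", flush=True)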