diff --git a/Makefile b/Makefile
index 1601079a4..d64f65a4b 100644
--- a/Makefile
+++ b/Makefile
@@ -176,7 +176,7 @@ $(info I CC: $(CCV))
 $(info I CXX: $(CXXV))
 $(info )
 
-default: main quantize
+default: main llamalib quantize
 
 #
 # Build library
@@ -194,6 +194,9 @@ clean:
 main: main.cpp ggml.o utils.o
 	$(CXX) $(CXXFLAGS) main.cpp ggml.o utils.o -o main $(LDFLAGS)
 	./main -h
+
+llamalib: expose.cpp ggml.o utils.o
+	$(CXX) $(CXXFLAGS) expose.cpp ggml.o utils.o -shared -o llamalib.dll $(LDFLAGS)
 
 quantize: quantize.cpp ggml.o utils.o
 	$(CXX) $(CXXFLAGS) quantize.cpp ggml.o utils.o -o quantize $(LDFLAGS)
diff --git a/expose.cpp b/expose.cpp
new file mode 100644
index 000000000..0ca6b67d8
--- /dev/null
+++ b/expose.cpp
@@ -0,0 +1,165 @@
+//This is Concedo's shitty adapter for adding python bindings for llama
+
+//Considerations:
+//Don't want to use pybind11 due to dependencies on MSVCC
+//ZERO or MINIMAL changes as possible to main.cpp - do not move their function declarations here!
+//Leave main.cpp UNTOUCHED, We want to be able to update the repo and pull any changes automatically.
+//No dynamic memory allocation! Setup structs with FIXED (known) shapes and sizes for ALL output fields
+//Python will ALWAYS provide the memory, we just write to it.
+
+#include "main.cpp"
+
+extern "C" {
+
+    struct load_model_inputs
+    {
+        const int threads;
+        const int max_context_length;
+        const int batch_size;
+        const char * model_filename;
+    };
+    struct generation_inputs
+    {
+        const int seed;
+        const char * prompt;
+        const int max_length;
+        const float temperature;
+        const int top_k;
+        const float top_p;
+        const float rep_pen;
+        const int rep_pen_range;
+    };
+    struct generation_outputs
+    {
+        int status;
+        char text[16384]; //16kb should be enough for any response
+    };
+
+    gpt_params api_params;
+    gpt_vocab api_vocab;
+    llama_model api_model;
+    int api_n_past = 0;
+    std::vector<float> api_logits;
+
+    bool load_model(const load_model_inputs inputs)
+    {
+        api_params.n_threads = inputs.threads;
+        api_params.n_ctx = inputs.max_context_length;
+        api_params.n_batch = inputs.batch_size;
+        api_params.model = inputs.model_filename;
+
+        if (!llama_model_load(api_params.model, api_model, api_vocab, api_params.n_ctx)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, api_params.model.c_str());
+            return false;
+        }
+
+        return true;
+    }
+
+    generation_outputs generate(const generation_inputs inputs, generation_outputs output)
+    {
+        api_params.prompt = inputs.prompt;
+        api_params.seed = inputs.seed;
+        api_params.n_predict = inputs.max_length;
+        api_params.top_k = inputs.top_k;
+        api_params.top_p = inputs.top_p;
+        api_params.temp = inputs.temperature;
+        api_params.repeat_last_n = inputs.rep_pen_range;
+        api_params.repeat_penalty = inputs.rep_pen;
+
+        if (api_params.seed < 0)
+        {
+            api_params.seed = time(NULL);
+        }
+
+        api_params.prompt.insert(0, 1, ' ');
+        // tokenize the prompt
+        std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(api_vocab, api_params.prompt, true);
+        api_params.n_predict = std::min(api_params.n_predict, api_model.hparams.n_ctx - (int)embd_inp.size());
+        std::vector<gpt_vocab::id> embd;
+        size_t mem_per_token = 0;
+        llama_eval(api_model, api_params.n_threads, 0, {0, 1, 2, 3}, api_logits, mem_per_token);
+
+        int last_n_size = api_params.repeat_last_n;
+        std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
+        std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+        int remaining_tokens = api_params.n_predict;
+        int input_consumed = 0;
+        std::mt19937 api_rng(api_params.seed);
+
+        std::string concat_output = "";
+
+        while (remaining_tokens > 0)
+        {
+            gpt_vocab::id id = 0;
+            // predict
+            if (embd.size() > 0)
+            {
+
+                if (!llama_eval(api_model, api_params.n_threads, api_n_past, embd, api_logits, mem_per_token))
+                {
+                    fprintf(stderr, "Failed to predict\n");
+                    _snprintf_s(output.text,sizeof(output.text),_TRUNCATE,"%s","");
+                    output.status = 0;
+                    return output;
+                }
+            }
+
+            api_n_past += embd.size();
+            embd.clear();
+
+            if (embd_inp.size() <= input_consumed)
+            {
+                // out of user input, sample next token
+                const float top_k = api_params.top_k;
+                const float top_p = api_params.top_p;
+                const float temp = api_params.temp;
+                const float repeat_penalty = api_params.repeat_penalty;
+                const int n_vocab = api_model.hparams.n_vocab;
+
+                {
+                    // set the logit of the eos token (2) to zero to avoid sampling it
+                    api_logits[api_logits.size() - n_vocab + 2] = 0;
+                    //set logits of opening square bracket to zero.
+                    api_logits[api_logits.size() - n_vocab + 518] = 0;
+                    api_logits[api_logits.size() - n_vocab + 29961] = 0;
+
+
+                    id = llama_sample_top_p_top_k(api_vocab, api_logits.data() + (api_logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, api_rng);
+
+                    last_n_tokens.erase(last_n_tokens.begin());
+                    last_n_tokens.push_back(id);
+                }
+
+                // add it to the context
+                embd.push_back(id);
+
+                // decrement remaining sampling budget
+                --remaining_tokens;
+
+                concat_output += api_vocab.id_to_token[id].c_str();
+            }
+            else
+            {
+                // some user input remains from prompt or interaction, forward it to processing
+                while (embd_inp.size() > input_consumed)
+                {
+                    embd.push_back(embd_inp[input_consumed]);
+                    last_n_tokens.erase(last_n_tokens.begin());
+                    last_n_tokens.push_back(embd_inp[input_consumed]);
+                    ++input_consumed;
+                    if (embd.size() > api_params.n_batch)
+                    {
+                        break;
+                    }
+                }
+            }
+
+        }
+
+        printf("output: %s",concat_output.c_str());
+        output.status = 1;
+        _snprintf_s(output.text,sizeof(output.text),_TRUNCATE,"%s",concat_output.c_str());
+        return output;
+    }
+}
\ No newline at end of file
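
As written, the new build target links the adapter as llamalib.dll and expose.cpp relies on MSVC's _snprintf_s, so this patch targets Windows. Purely as a sketch and not part of the patch (the .so/.dylib names are assumptions), a platform-aware loader on the Python side could look like this:

import ctypes, os, platform

# Hypothetical cross-platform loader; the patch itself only builds llamalib.dll.
libname = {"Windows": "llamalib.dll", "Darwin": "llamalib.dylib"}.get(platform.system(), "llamalib.so")
handle = ctypes.CDLL(os.path.join(os.path.dirname(os.path.realpath(__file__)), libname))
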
diff --git a/llama_for_kobold.py b/llama_for_kobold.py
new file mode 100644
index 000000000..333b81094
--- /dev/null
+++ b/llama_for_kobold.py
@@ -0,0 +1,245 @@
+# A hacky little script from Concedo that exposes llama.cpp function bindings
+# allowing it to be used via a simulated kobold api endpoint
+# it's not very usable as there is a fundamental flaw with llama.cpp
+# which causes generation delay to scale linearly with original prompt length.
+
+import ctypes
+import os
+
+class load_model_inputs(ctypes.Structure):
+    _fields_ = [("threads", ctypes.c_int),
+                ("max_context_length", ctypes.c_int),
+                ("batch_size", ctypes.c_int),
+                ("model_filename", ctypes.c_char_p)]
+
+class generation_inputs(ctypes.Structure):
+    _fields_ = [("seed", ctypes.c_int),
+                ("prompt", ctypes.c_char_p),
+                ("max_length", ctypes.c_int),
+                ("temperature", ctypes.c_float),
+                ("top_k", ctypes.c_int),
+                ("top_p", ctypes.c_float),
+                ("rep_pen", ctypes.c_float),
+                ("rep_pen_range", ctypes.c_int)]
+
+class generation_outputs(ctypes.Structure):
+    _fields_ = [("status", ctypes.c_int),
+                ("text", ctypes.c_char * 16384)]
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+handle = ctypes.CDLL(dir_path + "/llamalib.dll")
+
+handle.load_model.argtypes = [load_model_inputs]
+handle.load_model.restype = ctypes.c_bool
+handle.generate.argtypes = [generation_inputs]
+handle.generate.restype = generation_outputs
+
+def load_model(model_filename,batch_size=8,max_context_length=512,threads=4):
+    inputs = load_model_inputs()
+    inputs.model_filename = model_filename.encode("UTF-8")
+    inputs.batch_size = batch_size
+    inputs.max_context_length = max_context_length
+    inputs.threads = threads
+    ret = handle.load_model(inputs)
+    return ret
+
+def generate(prompt,max_length=20,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1.1,rep_pen_range=128,seed=-1):
+    inputs = generation_inputs()
+    outputs = generation_outputs()
+    inputs.prompt = prompt.encode("UTF-8")
+    inputs.max_length = max_length
+    inputs.temperature = temperature
+    inputs.top_k = top_k
+    inputs.top_p = top_p
+    inputs.rep_pen = rep_pen
+    inputs.rep_pen_range = rep_pen_range
+    inputs.seed = seed
+    ret = handle.generate(inputs,outputs)
+    if(ret.status==1):
+        return ret.text.decode("UTF-8")
+    return ""
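
The two ctypes wrappers above are enough to drive the model directly, without the Kobold-compatible HTTP server that follows. A minimal sketch, not part of the patch, with a placeholder model path:

# Hypothetical direct use of the bindings above; "model_q4_0.bin" is a placeholder path.
if load_model("model_q4_0.bin", batch_size=8, max_context_length=512, threads=4):
    print(generate("Niko the kobold stalked carefully,", max_length=40, temperature=0.8, seed=-1))
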
+
+
+#################################################################
+### A hacky simple HTTP server simulating a kobold api by Concedo
+### we are intentionally NOT using flask, because we want MINIMAL dependencies
+#################################################################
+import json, http.server, threading, socket, sys, time
+
+# global vars
+global modelname
+modelname = ""
+maxctx = 1024
+maxlen = 256
+modelbusy = False
+port = 5001
+
+class ServerRequestHandler(http.server.BaseHTTPRequestHandler):
+
+    sys_version = ""
+    server_version = "ConcedoLlamaForKoboldServer"
+
+    def do_GET(self):
+        if not self.path.endswith('/'):
+            # redirect browser
+            self.send_response(301)
+            self.send_header("Location", self.path + "/")
+            self.end_headers()
+            return
+
+        if self.path.endswith('/api/v1/model/') or self.path.endswith('/api/latest/model/'):
+            self.send_response(200)
+            self.end_headers()
+            global modelname
+            self.wfile.write(json.dumps({"result": modelname }).encode())
+            return
+
+        if self.path.endswith('/api/v1/config/max_length/') or self.path.endswith('/api/latest/config/max_length/'):
+            self.send_response(200)
+            self.end_headers()
+            global maxlen
+            self.wfile.write(json.dumps({"value":maxlen}).encode())
+            return
+
+        if self.path.endswith('/api/v1/config/max_context_length/') or self.path.endswith('/api/latest/config/max_context_length/'):
+            self.send_response(200)
+            self.end_headers()
+            global maxctx
+            self.wfile.write(json.dumps({"value":maxctx}).encode())
+            return
+
+        self.send_response(404)
+        self.end_headers()
+        rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
+        self.wfile.write(rp.encode())
+        return
+
+    def do_POST(self):
+        content_length = int(self.headers['Content-Length'])
+        body = self.rfile.read(content_length)
+        if self.path.endswith('/api/v1/generate/') or self.path.endswith('/api/latest/generate/'):
+            global modelbusy
+            if modelbusy:
+                self.send_response(503)
+                self.end_headers()
+                self.wfile.write(json.dumps({"detail": {
+                    "msg": "Server is busy; please try again later.",
+                    "type": "service_unavailable",
+                }}).encode())
+                return
+            else:
+                modelbusy = True
+                genparams = None
+                try:
+                    genparams = json.loads(body)
+                except ValueError as e:
+                    self.send_response(503)
+                    self.end_headers()
+                    return
+
+                print("\nInput: " + json.dumps(genparams))
+                recvtxt = generate(
+                    prompt=genparams.get('prompt', ""),
+                    max_length=genparams.get('max_length', 50),
+                    temperature=genparams.get('temperature', 0.8),
+                    top_k=genparams.get('top_k', 100),
+                    top_p=genparams.get('top_p', 0.85),
+                    rep_pen=genparams.get('rep_pen', 1.1),
+                    rep_pen_range=genparams.get('rep_pen_range', 128),
+                    seed=-1
+                    )
+                print("\nOutput: " + recvtxt)
+                res = {"results": [{"text": recvtxt}]}
+                self.send_response(200)
+                self.end_headers()
+                self.wfile.write(json.dumps(res).encode())
+                modelbusy = False
+                return
+        self.send_response(404)
+        self.end_headers()
+
+    def do_OPTIONS(self):
+        self.send_response(200)
+        self.end_headers()
+
+    def do_HEAD(self):
+        self.send_response(200)
+        self.end_headers()
+
+    def end_headers(self):
+        self.send_header('Access-Control-Allow-Origin', '*')
+        self.send_header('Access-Control-Allow-Methods', '*')
+        self.send_header('Access-Control-Allow-Headers', '*')
+        self.send_header('Content-type', 'application/json')
+        return super(ServerRequestHandler, self).end_headers()
+
+
+def RunServerMultiThreaded(port, HandlerClass = ServerRequestHandler,
+         ServerClass = http.server.HTTPServer):
+    addr = ('', port)
+    sock = socket.socket (socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    sock.bind(addr)
+    sock.listen(5)
+
+    # Start listener threads.
+    class Thread(threading.Thread):
+        def __init__(self, i):
+            threading.Thread.__init__(self)
+            self.i = i
+            self.daemon = True
+            self.start()
+        def run(self):
+            with http.server.HTTPServer(addr, HandlerClass, False) as self.httpd:
+                #print("Thread %s - Web Server is running at http://0.0.0.0:%s" % (self.i, port))
+                try:
+                    self.httpd.socket = sock
+                    self.httpd.server_bind = self.server_close = lambda self: None
+                    self.httpd.serve_forever()
+                except (KeyboardInterrupt,SystemExit):
+                    #print("Thread %s - Server Closing" % (self.i))
+                    self.httpd.server_close()
+                    sys.exit(0)
+                finally:
+                    # Clean-up server (close socket, etc.)
+                    self.httpd.server_close()
+                    sys.exit(0)
+        def stop(self):
+            self.httpd.server_close()
+
+    numThreads = 5
+    threadArr = []
+    for i in range(numThreads):
+        threadArr.append(Thread(i))
+    while 1:
+        try:
+            time.sleep(2000)
+        except KeyboardInterrupt:
+            for i in range(numThreads):
+                threadArr[i].stop()
+            sys.exit(0)
+
+if __name__ == '__main__':
+    # total arguments
+    argc = len(sys.argv)
+
+    if argc<2:
+        print("Usage: " + sys.argv[0] + " model_file_q4_0.bin [port]")
+        exit()
+    if argc>=3:
+        port = int(sys.argv[2])
+
+    if not os.path.exists(sys.argv[1]):
+        print("Cannot find model file: " + sys.argv[1])
+        exit()
+
+    modelname = os.path.abspath(sys.argv[1])
+    print("Loading model: " + modelname)
+    loadok = load_model(modelname,128,maxctx,4)
+    print("Load Model OK: " + str(loadok))
+
+    if loadok:
+        print("Starting Kobold HTTP Server on port " + str(port))
+        print("Please connect to custom endpoint at http://localhost:"+str(port))
+        RunServerMultiThreaded(port)
+    
\ No newline at end of file
diff --git a/llamalib.dll b/llamalib.dll
new file mode 100644
index 000000000..362d635c1
Binary files /dev/null and b/llamalib.dll differ
diff --git a/main.exe b/main.exe
new file mode 100644
index 000000000..bac27dc7e
Binary files /dev/null and b/main.exe differ
diff --git a/quantize.exe b/quantize.exe
new file mode 100644
index 000000000..348b2e229
Binary files /dev/null and b/quantize.exe differ
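
For reference, once llama_for_kobold.py is running, the simulated Kobold endpoint can be exercised with only the standard library. A minimal client sketch, not part of the patch (port 5001 is the script's default and the prompt is a placeholder):

import json, urllib.request

# Hypothetical client for the /api/v1/generate/ endpoint served by llama_for_kobold.py.
payload = json.dumps({"prompt": "Niko the kobold stalked carefully,",
                      "max_length": 50, "temperature": 0.8, "top_p": 0.85,
                      "rep_pen": 1.1, "rep_pen_range": 128}).encode()
req = urllib.request.Request("http://localhost:5001/api/v1/generate/", data=payload,
                             headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["results"][0]["text"])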