Created Python bindings for llama.cpp and emulated a simple Kobold HTTP API Endpoint
This commit is contained in:
parent a19b5a4adc
commit 2c8f870f53
6 changed files with 414 additions and 1 deletion
Makefile (5 changed lines)

@@ -176,7 +176,7 @@ $(info I CC: $(CCV))
 $(info I CXX: $(CXXV))
 $(info )
 
-default: main quantize
+default: main llamalib quantize
 
 #
 # Build library
@@ -194,6 +194,9 @@ clean:
 main: main.cpp ggml.o utils.o
 	$(CXX) $(CXXFLAGS) main.cpp ggml.o utils.o -o main $(LDFLAGS)
 	./main -h
 
+llamalib: expose.cpp ggml.o utils.o
+	$(CXX) $(CXXFLAGS) expose.cpp ggml.o utils.o -shared -o llamalib.dll $(LDFLAGS)
+
 quantize: quantize.cpp ggml.o utils.o
 	$(CXX) $(CXXFLAGS) quantize.cpp ggml.o utils.o -o quantize $(LDFLAGS)
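The new llamalib target links expose.cpp into a shared library that the Python script later loads with ctypes. This commit only builds and loads llamalib.dll; the snippet below is a minimal sketch of platform-aware loading, where the llamalib.so / llamalib.dylib names are assumptions about how an equivalent non-Windows build might be named, not something this commit provides.

# Sketch only: platform-aware loading of the llamalib shared library built above.
# The commit itself builds and loads llamalib.dll; the .so/.dylib names below are
# assumptions about an equivalent non-Windows build.
import ctypes
import os
import platform

def load_llamalib(base_dir=os.path.dirname(os.path.realpath(__file__))):
    names = {"Windows": "llamalib.dll", "Darwin": "llamalib.dylib"}
    libname = names.get(platform.system(), "llamalib.so")
    return ctypes.CDLL(os.path.join(base_dir, libname))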
expose.cpp (new file, 165 lines)

@@ -0,0 +1,165 @@
//This is Concedo's adapter for adding python bindings for llama

//Considerations:
//Don't want to use pybind11 due to its dependency on MSVC
//ZERO or MINIMAL changes to main.cpp - do not move its function declarations here!
//Leave main.cpp UNTOUCHED, we want to be able to update the repo and pull any changes automatically.
//No dynamic memory allocation! Set up structs with FIXED (known) shapes and sizes for ALL output fields
//Python will ALWAYS provide the memory, we just write to it.

#include "main.cpp"

extern "C" {

    struct load_model_inputs
    {
        const int threads;
        const int max_context_length;
        const int batch_size;
        const char * model_filename;
    };
    struct generation_inputs
    {
        const int seed;
        const char * prompt;
        const int max_length;
        const float temperature;
        const int top_k;
        const float top_p;
        const float rep_pen;
        const int rep_pen_range;
    };
    struct generation_outputs
    {
        int status;
        char text[16384]; //16kb should be enough for any response
    };

    gpt_params api_params;
    gpt_vocab api_vocab;
    llama_model api_model;
    int api_n_past = 0;
    std::vector<float> api_logits;

    bool load_model(const load_model_inputs inputs)
    {
        api_params.n_threads = inputs.threads;
        api_params.n_ctx = inputs.max_context_length;
        api_params.n_batch = inputs.batch_size;
        api_params.model = inputs.model_filename;

        if (!llama_model_load(api_params.model, api_model, api_vocab, api_params.n_ctx)) {
            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, api_params.model.c_str());
            return false;
        }

        return true;
    }

    generation_outputs generate(const generation_inputs inputs, generation_outputs output)
    {
        api_params.prompt = inputs.prompt;
        api_params.seed = inputs.seed;
        api_params.n_predict = inputs.max_length;
        api_params.top_k = inputs.top_k;
        api_params.top_p = inputs.top_p;
        api_params.temp = inputs.temperature;
        api_params.repeat_last_n = inputs.rep_pen_range;
        api_params.repeat_penalty = inputs.rep_pen;

        if (api_params.seed < 0)
        {
            api_params.seed = time(NULL);
        }

        api_params.prompt.insert(0, 1, ' ');
        // tokenize the prompt
        std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(api_vocab, api_params.prompt, true);
        api_params.n_predict = std::min(api_params.n_predict, api_model.hparams.n_ctx - (int)embd_inp.size());
        std::vector<gpt_vocab::id> embd;
        size_t mem_per_token = 0;
        llama_eval(api_model, api_params.n_threads, 0, {0, 1, 2, 3}, api_logits, mem_per_token);

        int last_n_size = api_params.repeat_last_n;
        std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
        std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
        int remaining_tokens = api_params.n_predict;
        int input_consumed = 0;
        std::mt19937 api_rng(api_params.seed);

        std::string concat_output = "";

        while (remaining_tokens > 0)
        {
            gpt_vocab::id id = 0;
            // predict
            if (embd.size() > 0)
            {
                if (!llama_eval(api_model, api_params.n_threads, api_n_past, embd, api_logits, mem_per_token))
                {
                    fprintf(stderr, "Failed to predict\n");
                    _snprintf_s(output.text, sizeof(output.text), _TRUNCATE, "%s", "");
                    output.status = 0;
                    return output;
                }
            }

            api_n_past += embd.size();
            embd.clear();

            if (embd_inp.size() <= input_consumed)
            {
                // out of user input, sample next token
                const float top_k = api_params.top_k;
                const float top_p = api_params.top_p;
                const float temp = api_params.temp;
                const float repeat_penalty = api_params.repeat_penalty;
                const int n_vocab = api_model.hparams.n_vocab;

                {
                    // set the logit of the eos token (2) to zero to avoid sampling it
                    api_logits[api_logits.size() - n_vocab + 2] = 0;
                    //set logits of opening square bracket to zero.
                    api_logits[api_logits.size() - n_vocab + 518] = 0;
                    api_logits[api_logits.size() - n_vocab + 29961] = 0;

                    id = llama_sample_top_p_top_k(api_vocab, api_logits.data() + (api_logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, api_rng);

                    last_n_tokens.erase(last_n_tokens.begin());
                    last_n_tokens.push_back(id);
                }

                // add it to the context
                embd.push_back(id);

                // decrement remaining sampling budget
                --remaining_tokens;

                concat_output += api_vocab.id_to_token[id].c_str();
            }
            else
            {
                // some user input remains from prompt or interaction, forward it to processing
                while (embd_inp.size() > input_consumed)
                {
                    embd.push_back(embd_inp[input_consumed]);
                    last_n_tokens.erase(last_n_tokens.begin());
                    last_n_tokens.push_back(embd_inp[input_consumed]);
                    ++input_consumed;
                    if (embd.size() > api_params.n_batch)
                    {
                        break;
                    }
                }
            }

        }

        printf("output: %s", concat_output.c_str());
        output.status = 1;
        _snprintf_s(output.text, sizeof(output.text), _TRUNCATE, "%s", concat_output.c_str());
        return output;
    }
}
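Because generation_outputs carries a fixed 16384-byte text buffer and both exports take their structs by value, the Python side only needs matching ctypes layouts. The sketch below condenses the binding that llama_for_kobold.py (next file) defines into a direct smoke test; the model path and prompt are placeholders, not part of this commit.

# Sketch only: condensed smoke test of the two exports above, mirroring the
# ctypes layout that llama_for_kobold.py defines. Paths/prompts are placeholders.
import ctypes
import os

class load_model_inputs(ctypes.Structure):
    _fields_ = [("threads", ctypes.c_int), ("max_context_length", ctypes.c_int),
                ("batch_size", ctypes.c_int), ("model_filename", ctypes.c_char_p)]

class generation_inputs(ctypes.Structure):
    _fields_ = [("seed", ctypes.c_int), ("prompt", ctypes.c_char_p), ("max_length", ctypes.c_int),
                ("temperature", ctypes.c_float), ("top_k", ctypes.c_int), ("top_p", ctypes.c_float),
                ("rep_pen", ctypes.c_float), ("rep_pen_range", ctypes.c_int)]

class generation_outputs(ctypes.Structure):
    _fields_ = [("status", ctypes.c_int), ("text", ctypes.c_char * 16384)]

lib = ctypes.CDLL(os.path.join(os.path.dirname(os.path.realpath(__file__)), "llamalib.dll"))
lib.load_model.argtypes = [load_model_inputs]
lib.load_model.restype = ctypes.c_bool
lib.generate.argtypes = [generation_inputs, generation_outputs]
lib.generate.restype = generation_outputs

li = load_model_inputs(threads=4, max_context_length=512, batch_size=8,
                       model_filename=b"model_file_q4_0.bin")  # placeholder path
if lib.load_model(li):
    gi = generation_inputs(seed=-1, prompt=b"Hello", max_length=20, temperature=0.8,
                           top_k=100, top_p=0.85, rep_pen=1.1, rep_pen_range=128)
    out = lib.generate(gi, generation_outputs())
    print(out.text.decode("UTF-8", errors="ignore"))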
llama_for_kobold.py (new file, 245 lines)

@@ -0,0 +1,245 @@
# A hacky little script from Concedo that exposes llama.cpp function bindings
# allowing it to be used via a simulated kobold api endpoint
# it's not very usable as there is a fundamental flaw with llama.cpp
# which causes generation delay to scale linearly with original prompt length.

import ctypes
import os

class load_model_inputs(ctypes.Structure):
    _fields_ = [("threads", ctypes.c_int),
                ("max_context_length", ctypes.c_int),
                ("batch_size", ctypes.c_int),
                ("model_filename", ctypes.c_char_p)]

class generation_inputs(ctypes.Structure):
    _fields_ = [("seed", ctypes.c_int),
                ("prompt", ctypes.c_char_p),
                ("max_length", ctypes.c_int),
                ("temperature", ctypes.c_float),
                ("top_k", ctypes.c_int),
                ("top_p", ctypes.c_float),
                ("rep_pen", ctypes.c_float),
                ("rep_pen_range", ctypes.c_int)]

class generation_outputs(ctypes.Structure):
    _fields_ = [("status", ctypes.c_int),
                ("text", ctypes.c_char * 16384)]

dir_path = os.path.dirname(os.path.realpath(__file__))
handle = ctypes.CDLL(dir_path + "/llamalib.dll")

handle.load_model.argtypes = [load_model_inputs]
handle.load_model.restype = ctypes.c_bool
handle.generate.argtypes = [generation_inputs]
handle.generate.restype = generation_outputs

def load_model(model_filename, batch_size=8, max_context_length=512, threads=4):
    inputs = load_model_inputs()
    inputs.model_filename = model_filename.encode("UTF-8")
    inputs.batch_size = batch_size
    inputs.max_context_length = max_context_length
    inputs.threads = threads
    ret = handle.load_model(inputs)
    return ret

def generate(prompt, max_length=20, temperature=0.8, top_k=100, top_p=0.85, rep_pen=1.1, rep_pen_range=128, seed=-1):
    inputs = generation_inputs()
    outputs = generation_outputs()
    inputs.prompt = prompt.encode("UTF-8")
    inputs.max_length = max_length
    inputs.temperature = temperature
    inputs.top_k = top_k
    inputs.top_p = top_p
    inputs.rep_pen = rep_pen
    inputs.rep_pen_range = rep_pen_range
    inputs.seed = seed
    ret = handle.generate(inputs, outputs)
    if(ret.status==1):
        return ret.text.decode("UTF-8")
    return ""


#################################################################
### A hacky simple HTTP server simulating a kobold api by Concedo
### we are intentionally NOT using flask, because we want MINIMAL dependencies
#################################################################
import json, http.server, threading, socket, sys, time

# global vars
global modelname
modelname = ""
maxctx = 1024
maxlen = 256
modelbusy = False
port = 5001

class ServerRequestHandler(http.server.BaseHTTPRequestHandler):

    sys_version = ""
    server_version = "ConcedoLlamaForKoboldServer"

    def do_GET(self):
        if not self.path.endswith('/'):
            # redirect browser
            self.send_response(301)
            self.send_header("Location", self.path + "/")
            self.end_headers()
            return

        if self.path.endswith('/api/v1/model/') or self.path.endswith('/api/latest/model/'):
            self.send_response(200)
            self.end_headers()
            global modelname
            self.wfile.write(json.dumps({"result": modelname}).encode())
            return

        if self.path.endswith('/api/v1/config/max_length/') or self.path.endswith('/api/latest/config/max_length/'):
            self.send_response(200)
            self.end_headers()
            global maxlen
            self.wfile.write(json.dumps({"value": maxlen}).encode())
            return

        if self.path.endswith('/api/v1/config/max_context_length/') or self.path.endswith('/api/latest/config/max_context_length/'):
            self.send_response(200)
            self.end_headers()
            global maxctx
            self.wfile.write(json.dumps({"value": maxctx}).encode())
            return

        self.send_response(404)
        self.end_headers()
        rp = 'Error: HTTP Server is running, but this endpoint does not exist. Please check the URL.'
        self.wfile.write(rp.encode())
        return

    def do_POST(self):
        content_length = int(self.headers['Content-Length'])
        body = self.rfile.read(content_length)
        if self.path.endswith('/api/v1/generate/') or self.path.endswith('/api/latest/generate/'):
            global modelbusy
            if modelbusy:
                self.send_response(503)
                self.end_headers()
                self.wfile.write(json.dumps({"detail": {
                    "msg": "Server is busy; please try again later.",
                    "type": "service_unavailable",
                }}).encode())
                return
            else:
                modelbusy = True
            genparams = None
            try:
                genparams = json.loads(body)
            except ValueError as e:
                self.send_response(503)
                self.end_headers()
                return

            print("\nInput: " + json.dumps(genparams))
            recvtxt = generate(
                prompt=genparams.get('prompt', ""),
                max_length=genparams.get('max_length', 50),
                temperature=genparams.get('temperature', 0.8),
                top_k=genparams.get('top_k', 100),
                top_p=genparams.get('top_p', 0.85),
                rep_pen=genparams.get('rep_pen', 1.1),
                rep_pen_range=genparams.get('rep_pen_range', 128),
                seed=-1
                )
            print("\nOutput: " + recvtxt)
            res = {"results": [{"text": recvtxt}]}
            self.send_response(200)
            self.end_headers()
            self.wfile.write(json.dumps(res).encode())
            modelbusy = False
            return
        self.send_response(404)
        self.end_headers()

    def do_OPTIONS(self):
        self.send_response(200)
        self.end_headers()

    def do_HEAD(self):
        self.send_response(200)
        self.end_headers()

    def end_headers(self):
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Methods', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        self.send_header('Content-type', 'application/json')
        return super(ServerRequestHandler, self).end_headers()


def RunServerMultiThreaded(port, HandlerClass=ServerRequestHandler,
                           ServerClass=http.server.HTTPServer):
    addr = ('', port)
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(addr)
    sock.listen(5)

    # Start listener threads.
    class Thread(threading.Thread):
        def __init__(self, i):
            threading.Thread.__init__(self)
            self.i = i
            self.daemon = True
            self.start()
        def run(self):
            with http.server.HTTPServer(addr, HandlerClass, False) as self.httpd:
                #print("Thread %s - Web Server is running at http://0.0.0.0:%s" % (self.i, port))
                try:
                    self.httpd.socket = sock
                    self.httpd.server_bind = self.server_close = lambda self: None
                    self.httpd.serve_forever()
                except (KeyboardInterrupt, SystemExit):
                    #print("Thread %s - Server Closing" % (self.i))
                    self.httpd.server_close()
                    sys.exit(0)
                finally:
                    # Clean-up server (close socket, etc.)
                    self.httpd.server_close()
                    sys.exit(0)
        def stop(self):
            self.httpd.server_close()

    numThreads = 5
    threadArr = []
    for i in range(numThreads):
        threadArr.append(Thread(i))
    while 1:
        try:
            time.sleep(2000)
        except KeyboardInterrupt:
            for i in range(numThreads):
                threadArr[i].stop()
            sys.exit(0)

if __name__ == '__main__':
    # total arguments
    argc = len(sys.argv)

    if argc < 2:
        print("Usage: " + sys.argv[0] + " model_file_q4_0.bin [port]")
        exit()
    if argc >= 3:
        port = int(sys.argv[2])

    if not os.path.exists(sys.argv[1]):
        print("Cannot find model file: " + sys.argv[1])
        exit()

    modelname = os.path.abspath(sys.argv[1])
    print("Loading model: " + modelname)
    loadok = load_model(modelname, 128, maxctx, 4)
    print("Load Model OK: " + str(loadok))

    if loadok:
        print("Starting Kobold HTTP Server on port " + str(port))
        print("Please connect to custom endpoint at http://localhost:" + str(port))
        RunServerMultiThreaded(port)
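For reference, a client can exercise the emulated Kobold endpoint with nothing beyond the standard library, in keeping with the script's minimal-dependency approach. A minimal sketch, assuming the server above is running locally on the default port 5001; the payload field names and response shape mirror what do_POST reads and writes:

# Sketch only: call the simulated Kobold generate endpoint served above.
# Assumes the server is running locally on the default port 5001.
import json
import urllib.request

payload = {
    "prompt": "Hello, my name is",
    "max_length": 50,
    "temperature": 0.8,
    "top_k": 100,
    "top_p": 0.85,
    "rep_pen": 1.1,
    "rep_pen_range": 128,
}
req = urllib.request.Request(
    "http://localhost:5001/api/v1/generate/",
    data=json.dumps(payload).encode("UTF-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    # The handler replies with {"results": [{"text": ...}]}
    print(json.loads(resp.read())["results"][0]["text"])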
llamalib.dll (new binary file, not shown)
main.exe (new binary file, not shown)
quantize.exe (new binary file, not shown)