Commit dda69d4034 by Concedo, 2023-03-20 13:37:51 +08:00
Parent 474f760411
2 changed files with 11 additions and 7 deletions

README.md

@@ -16,4 +16,4 @@ If you care, **please contribute to [this discussion](https://github.com/ggergan
 ## Usage
 - Windows binaries are provided in the form of **llamacpp.dll** but if you feel worried go ahead and rebuild it yourself.
 - Weights are not included, you can use the llama.cpp quantize.exe to generate them from your official weight files (or download them from...places).
-- To run, simply clone the repo and run `llama_for_kobold.py [ggml_quant_model.bin] [port]`, and then connect with Kobold or Kobold Lite.
+- To run, simply clone the repo and run `llama_for_kobold.py [ggml_quant_model.bin] [port]`, and then connect with Kobold or Kobold Lite (for example, https://lite.koboldai.net/?local=1&port=5001).
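As a quick usage illustration (not part of the diff; the model filename below is hypothetical), the run command and the matching Kobold Lite URL would look like:

python llama_for_kobold.py ggml-model-q4_0.bin 5001
https://lite.koboldai.net/?local=1&port=5001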

llama_for_kobold.py

@@ -5,6 +5,7 @@
 import ctypes
 import os
+#from pathlib import Path
 
 class load_model_inputs(ctypes.Structure):
     _fields_ = [("threads", ctypes.c_int),
@@ -33,7 +34,7 @@ handle = ctypes.CDLL(dir_path + "/llamacpp.dll")
 handle.load_model.argtypes = [load_model_inputs]
 handle.load_model.restype = ctypes.c_bool
-handle.generate.argtypes = [generation_inputs]
+handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
 handle.generate.restype = generation_outputs
 
 def load_model(model_filename,batch_size=8,max_context_length=512,threads=4,n_parts_overwrite=-1):
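For background, the argtypes/restype assignments above are the standard ctypes binding pattern: they declare how each argument and the return value should be marshalled instead of letting ctypes guess. A minimal sketch of the same pattern against libc's strlen (assumes a POSIX system where find_library("c") resolves; not part of this commit):

import ctypes
import ctypes.util

# Load the C standard library and declare strlen's signature explicitly,
# just like the handle.load_model / handle.generate declarations above.
libc = ctypes.CDLL(ctypes.util.find_library("c"))
libc.strlen.argtypes = [ctypes.c_char_p]   # marshal Python bytes as a C char*
libc.strlen.restype = ctypes.c_size_t      # interpret the return value as size_t

print(libc.strlen(b"hello"))   # prints 5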
@@ -71,8 +72,8 @@ def generate(prompt,max_length=20,temperature=0.8,top_k=100,top_p=0.85,rep_pen=1
 import json, http.server, threading, socket, sys, time
 
 # global vars
-global modelname
-modelname = ""
+global friendlymodelname
+friendlymodelname = ""
 maxctx = 1024
 maxlen = 256
 modelbusy = False
@@ -95,8 +96,8 @@ class ServerRequestHandler(http.server.BaseHTTPRequestHandler):
         if self.path.endswith('/api/v1/model/') or self.path.endswith('/api/latest/model/'):
             self.send_response(200)
             self.end_headers()
-            global modelname
-            self.wfile.write(json.dumps({"result": modelname }).encode())
+            global friendlymodelname
+            self.wfile.write(json.dumps({"result": friendlymodelname }).encode())
             return
 
         if self.path.endswith('/api/v1/config/max_length/') or self.path.endswith('/api/latest/config/max_length/'):
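A hedged client-side view of the model-info route handled above (assumes the server is already running locally on port 5001; the endpoint path is taken from the handler code):

import json
import urllib.request

# Ask the local Kobold-compatible API which model it reports.
with urllib.request.urlopen("http://localhost:5001/api/v1/model/") as resp:
    print(json.loads(resp.read()))   # e.g. {"result": "concedo/llamacpp"}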
@@ -122,7 +123,7 @@ class ServerRequestHandler(http.server.BaseHTTPRequestHandler):
     def do_POST(self):
         content_length = int(self.headers['Content-Length'])
         body = self.rfile.read(content_length)
-        if self.path.endswith('/api/v1/generate/') or self.path.endswith('/api/latest/generate/'):
+        if self.path.endswith('/api/v1/generate/') or self.path.endswith('/api/latest/generate/') or self.path.endswith('/api/v1/generate') or self.path.endswith('/api/latest/generate'):
            global modelbusy
            global last_context
            if modelbusy:
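The widened condition above simply accepts the generate endpoint with or without a trailing slash; a standalone sketch of the same check (the helper name is made up for illustration):

# Accept /api/v1/generate and /api/latest/generate, with or without a trailing slash.
def is_generate_path(path):
    return (path.endswith('/api/v1/generate/') or path.endswith('/api/latest/generate/')
            or path.endswith('/api/v1/generate') or path.endswith('/api/latest/generate'))

for p in ('/api/v1/generate', '/api/v1/generate/', '/api/latest/generate', '/api/v1/model/'):
    print(p, is_generate_path(p))   # True, True, True, False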
@@ -257,6 +258,9 @@ if __name__ == '__main__':
     loadok = load_model(modelname,24,maxctx,4,mdl_nparts)
     print("Load Model OK: " + str(loadok))
 
+    #friendlymodelname = Path(modelname).stem ### this wont work on local kobold api, so we must hardcode a known HF model name
+    friendlymodelname = "concedo/llamacpp"
+
     if loadok:
         print("Starting Kobold HTTP Server on port " + str(port))
         print("Please connect to custom endpoint at http://localhost:"+str(port))