added lora support
parent c454f8b848
commit 6e908c1792
4 changed files with 45 additions and 19 deletions
2 expose.cpp
@@ -21,6 +21,7 @@
 #include "model_adapter.cpp"

 std::string executable_path = "";
+std::string lora_filename = "";

 extern "C"
 {
@@ -33,6 +34,7 @@ extern "C"
     bool load_model(const load_model_inputs inputs)
     {
         std::string model = inputs.model_filename;
+        lora_filename = inputs.lora_filename;
         file_format = check_file_format(model.c_str());

         //first digit is whether configured, second is platform, third is devices
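Note on the flow above: koboldcpp.py fills a load_model_inputs struct and calls this extern "C" entry point, which stashes inputs.lora_filename in the new file-scope std::string so that gpttype_load_model can read it later. A minimal sketch of the Python calling side (file names are placeholders; handle is assumed to be the ctypes library handle set up by init_library()):

    inputs = load_model_inputs()
    inputs.model_filename = "ggml-model.bin".encode("UTF-8")
    inputs.lora_filename = "my-lora.bin".encode("UTF-8")  # empty string means no adapter
    ret = handle.load_model(inputs)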
9 expose.h
@@ -7,9 +7,9 @@ struct load_model_inputs
     const int max_context_length;
     const int batch_size;
     const bool f16_kv;
-    const char *executable_path;
-    const char *model_filename;
-    const int n_parts_overwrite = -1;
+    const char * executable_path;
+    const char * model_filename;
+    const char * lora_filename;
     const bool use_mmap;
     const bool use_smartcontext;
     const int clblast_info = 0;
@@ -34,4 +34,5 @@ struct generation_outputs
     char text[16384]; //16kb should be enough for any response
 };

-extern std::string executable_path;
+extern std::string executable_path;
+extern std::string lora_filename;
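Because lora_filename lands in the middle of load_model_inputs (replacing n_parts_overwrite), the _fields_ list of the matching ctypes.Structure in koboldcpp.py must change in lockstep (see the koboldcpp.py hunk below); a stale mirror would read every member after the mismatch at the wrong offset. A quick way to sanity-check the Python side after such an edit, using an abbreviated and purely illustrative mirror of the struct, is to dump each field's byte offset and compare it against offsetof() on the C side:

    import ctypes

    class load_model_inputs(ctypes.Structure):  # abbreviated, illustrative only
        _fields_ = [("max_context_length", ctypes.c_int),
                    ("batch_size", ctypes.c_int),
                    ("f16_kv", ctypes.c_bool),
                    ("executable_path", ctypes.c_char_p),
                    ("model_filename", ctypes.c_char_p),
                    ("lora_filename", ctypes.c_char_p)]

    # print each field's byte offset; compare with offsetof() in C
    for name, _ in load_model_inputs._fields_:
        print(name, getattr(load_model_inputs, name).offset)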
17 gpttype_adapter.cpp
@@ -76,7 +76,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     {
         llama_ctx_params = llama_context_default_params();
         llama_ctx_params.n_ctx = inputs.max_context_length;
-        llama_ctx_params.n_parts = -1;//inputs.n_parts_overwrite;
+        llama_ctx_params.n_parts = -1;
         llama_ctx_params.seed = -1;
         llama_ctx_params.f16_kv = inputs.f16_kv;
         llama_ctx_params.logits_all = false;
@@ -95,6 +95,21 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         printf("\n---\nWarning: Your model has an INVALID or OUTDATED format (ver %d). Please reconvert it for better results!\n---\n", file_format);
     }

+    if (lora_filename != "")
+    {
+        printf("\nAttempting to apply LORA adapter: %s\n", lora_filename.c_str());
+
+        int err = llama_apply_lora_from_file(llama_ctx_v1,
+                                             lora_filename.c_str(),
+                                             NULL,
+                                             n_threads);
+        if (err != 0)
+        {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return ModelLoadResult::FAIL;
+        }
+    }
+
     //determine mem per token
     const std::vector<int> tmp = {0, 1, 2, 3};
     llama_eval(llama_ctx_v1, tmp.data(), tmp.size(), 0, params.n_threads);
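llama_apply_lora_from_file(ctx, path_lora, path_base_model, n_threads) returns 0 on success; passing NULL as path_base_model, as above, applies the adapter directly onto the weights already loaded in the context rather than layering it over a separate f16 base model. For reference, a hedged sketch of the same call made from Python via ctypes, not koboldcpp's actual binding; the library path and file name are placeholders, and ctx is assumed to be a llama_context pointer already created through the same C API:

    import ctypes

    lib = ctypes.CDLL("./libllama.so")  # placeholder path to a llama.cpp build
    lib.llama_apply_lora_from_file.restype = ctypes.c_int
    lib.llama_apply_lora_from_file.argtypes = [ctypes.c_void_p, ctypes.c_char_p,
                                               ctypes.c_char_p, ctypes.c_int]

    ctx = ...  # assumed: an existing llama_context pointer
    err = lib.llama_apply_lora_from_file(ctx, b"my-lora.bin", None, 4)
    if err != 0:
        raise RuntimeError("failed to apply lora adapter")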
36 koboldcpp.py
@@ -16,7 +16,7 @@ class load_model_inputs(ctypes.Structure):
                 ("f16_kv", ctypes.c_bool),
                 ("executable_path", ctypes.c_char_p),
                 ("model_filename", ctypes.c_char_p),
-                ("n_parts_overwrite", ctypes.c_int),
+                ("lora_filename", ctypes.c_char_p),
                 ("use_mmap", ctypes.c_bool),
                 ("use_smartcontext", ctypes.c_bool),
                 ("clblast_info", ctypes.c_int),
@@ -89,17 +89,17 @@ def init_library():
     handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
     handle.generate.restype = generation_outputs

-def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1,threads=6,use_mmap=False,use_smartcontext=False,blasbatchsize=512):
+def load_model(model_filename):
     inputs = load_model_inputs()
     inputs.model_filename = model_filename.encode("UTF-8")
-    inputs.batch_size = batch_size
-    inputs.max_context_length = max_context_length #initial value to use for ctx, can be overwritten
-    inputs.threads = threads
-    inputs.n_parts_overwrite = n_parts_overwrite
+    inputs.lora_filename = args.lora.encode("UTF-8")
+    inputs.batch_size = 8
+    inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
+    inputs.threads = args.threads
     inputs.f16_kv = True
-    inputs.use_mmap = use_mmap
-    inputs.use_smartcontext = use_smartcontext
-    inputs.blasbatchsize = blasbatchsize
+    inputs.use_mmap = (not args.nommap)
+    inputs.use_smartcontext = args.smartcontext
+    inputs.blasbatchsize = args.blasbatchsize
     clblastids = 0
     if args.useclblast:
         clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
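Two details worth noting in this hunk: ctypes.c_char_p fields take bytes, hence the .encode("UTF-8"); and because --lora defaults to "", the C++ side receives an empty string and its lora_filename != "" guard skips adapter loading entirely. load_model() also drops all of its tuning parameters in favor of the module-level args (and maxctx), so callers can no longer override batch size or thread count per call.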
@@ -403,7 +403,7 @@ def main(args):
     embedded_kailite = None
     ggml_selected_file = args.model_param
     if not ggml_selected_file:
-        ggml_selected_file = args.model
+        ggml_selected_file = args.model
     if not ggml_selected_file:
         #give them a chance to pick a file
         print("For command line arguments, please refer to --help")
@@ -430,10 +430,17 @@ def main(args):
         time.sleep(2)
         sys.exit(2)

-    mdl_nparts = sum(1 for n in range(1, 9) if os.path.exists(f"{ggml_selected_file}.{n}")) + 1
+    if args.lora and args.lora!="":
+        if not os.path.exists(args.lora):
+            print(f"Cannot find lora file: {args.lora}")
+            time.sleep(2)
+            sys.exit(2)
+        else:
+            args.lora = os.path.abspath(args.lora)
+
     modelname = os.path.abspath(ggml_selected_file)
-    print(f"Loading model: {modelname} \n[Parts: {mdl_nparts}, Threads: {args.threads}, SmartContext: {args.smartcontext}]")
-    loadok = load_model(modelname,8,maxctx,mdl_nparts,args.threads,(not args.nommap),args.smartcontext,args.blasbatchsize)
+    print(f"Loading model: {modelname} \n[Threads: {args.threads}, SmartContext: {args.smartcontext}]")
+    loadok = load_model(modelname)
     print("Load Model OK: " + str(loadok))

     if not loadok:
@@ -477,7 +484,8 @@ if __name__ == '__main__':
     portgroup.add_argument("port_param", help="Port to listen on (positional)", default=defaultport, nargs="?", type=int, action='store')
     parser.add_argument("--host", help="Host IP to listen on. If empty, all routable interfaces are accepted.", default="")
     parser.add_argument("--launch", help="Launches a web browser when load is completed.", action='store_true')
+    parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", default="")

     #os.environ["OMP_NUM_THREADS"] = '12'
     # psutil.cpu_count(logical=False)
     physical_core_limit = 1
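With the flag wired end to end, a typical launch might look like this (file names are placeholders; per the help text, the adapter is experimental and applies to LLAMA-format models only):

    python koboldcpp.py ggml-model-q4_0.bin --lora my-lora.bin --threads 8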