added lora support

This commit is contained in:
Concedo 2023-04-22 12:29:38 +08:00
parent c454f8b848
commit 6e908c1792
4 changed files with 45 additions and 19 deletions

View file

@ -21,6 +21,7 @@
#include "model_adapter.cpp"
std::string executable_path = "";
std::string lora_filename = "";
extern "C"
{
@ -33,6 +34,7 @@ extern "C"
bool load_model(const load_model_inputs inputs)
{
std::string model = inputs.model_filename;
lora_filename = inputs.lora_filename;
file_format = check_file_format(model.c_str());
//first digit is whether configured, second is platform, third is devices

View file

@ -7,9 +7,9 @@ struct load_model_inputs
const int max_context_length;
const int batch_size;
const bool f16_kv;
const char *executable_path;
const char *model_filename;
const int n_parts_overwrite = -1;
const char * executable_path;
const char * model_filename;
const char * lora_filename;
const bool use_mmap;
const bool use_smartcontext;
const int clblast_info = 0;
@ -34,4 +34,5 @@ struct generation_outputs
char text[16384]; //16kb should be enough for any response
};
extern std::string executable_path;
extern std::string executable_path;
extern std::string lora_filename;

View file

@ -76,7 +76,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
{
llama_ctx_params = llama_context_default_params();
llama_ctx_params.n_ctx = inputs.max_context_length;
llama_ctx_params.n_parts = -1;//inputs.n_parts_overwrite;
llama_ctx_params.n_parts = -1;
llama_ctx_params.seed = -1;
llama_ctx_params.f16_kv = inputs.f16_kv;
llama_ctx_params.logits_all = false;
@ -95,6 +95,21 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
printf("\n---\nWarning: Your model has an INVALID or OUTDATED format (ver %d). Please reconvert it for better results!\n---\n", file_format);
}
if (lora_filename != "")
{
printf("\nAttempting to apply LORA adapter: %s\n", lora_filename.c_str());
int err = llama_apply_lora_from_file(llama_ctx_v1,
lora_filename.c_str(),
NULL,
n_threads);
if (err != 0)
{
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
return ModelLoadResult::FAIL;
}
}
//determine mem per token
const std::vector<int> tmp = {0, 1, 2, 3};
llama_eval(llama_ctx_v1, tmp.data(), tmp.size(), 0, params.n_threads);

View file

@ -16,7 +16,7 @@ class load_model_inputs(ctypes.Structure):
("f16_kv", ctypes.c_bool),
("executable_path", ctypes.c_char_p),
("model_filename", ctypes.c_char_p),
("n_parts_overwrite", ctypes.c_int),
("lora_filename", ctypes.c_char_p),
("use_mmap", ctypes.c_bool),
("use_smartcontext", ctypes.c_bool),
("clblast_info", ctypes.c_int),
@ -89,17 +89,17 @@ def init_library():
handle.generate.argtypes = [generation_inputs, ctypes.c_wchar_p] #apparently needed for osx to work. i duno why they need to interpret it that way but whatever
handle.generate.restype = generation_outputs
def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1,threads=6,use_mmap=False,use_smartcontext=False,blasbatchsize=512):
def load_model(model_filename):
inputs = load_model_inputs()
inputs.model_filename = model_filename.encode("UTF-8")
inputs.batch_size = batch_size
inputs.max_context_length = max_context_length #initial value to use for ctx, can be overwritten
inputs.threads = threads
inputs.n_parts_overwrite = n_parts_overwrite
inputs.lora_filename = args.lora.encode("UTF-8")
inputs.batch_size = 8
inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
inputs.threads = args.threads
inputs.f16_kv = True
inputs.use_mmap = use_mmap
inputs.use_smartcontext = use_smartcontext
inputs.blasbatchsize = blasbatchsize
inputs.use_mmap = (not args.nommap)
inputs.use_smartcontext = args.smartcontext
inputs.blasbatchsize = args.blasbatchsize
clblastids = 0
if args.useclblast:
clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
@ -403,7 +403,7 @@ def main(args):
embedded_kailite = None
ggml_selected_file = args.model_param
if not ggml_selected_file:
ggml_selected_file = args.model
ggml_selected_file = args.model
if not ggml_selected_file:
#give them a chance to pick a file
print("For command line arguments, please refer to --help")
@ -430,10 +430,17 @@ def main(args):
time.sleep(2)
sys.exit(2)
mdl_nparts = sum(1 for n in range(1, 9) if os.path.exists(f"{ggml_selected_file}.{n}")) + 1
if args.lora and args.lora!="":
if not os.path.exists(args.lora):
print(f"Cannot find lora file: {args.lora}")
time.sleep(2)
sys.exit(2)
else:
args.lora = os.path.abspath(args.lora)
modelname = os.path.abspath(ggml_selected_file)
print(f"Loading model: {modelname} \n[Parts: {mdl_nparts}, Threads: {args.threads}, SmartContext: {args.smartcontext}]")
loadok = load_model(modelname,8,maxctx,mdl_nparts,args.threads,(not args.nommap),args.smartcontext,args.blasbatchsize)
print(f"Loading model: {modelname} \n[Threads: {args.threads}, SmartContext: {args.smartcontext}]")
loadok = load_model(modelname)
print("Load Model OK: " + str(loadok))
if not loadok:
@ -477,7 +484,8 @@ if __name__ == '__main__':
portgroup.add_argument("port_param", help="Port to listen on (positional)", default=defaultport, nargs="?", type=int, action='store')
parser.add_argument("--host", help="Host IP to listen on. If empty, all routable interfaces are accepted.", default="")
parser.add_argument("--launch", help="Launches a web browser when load is completed.", action='store_true')
parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", default="")
#os.environ["OMP_NUM_THREADS"] = '12'
# psutil.cpu_count(logical=False)
physical_core_limit = 1