added support for lora base

Concedo 2023-06-10 19:29:45 +08:00
parent 375540837e
commit 66a3f4e421
4 changed files with 30 additions and 8 deletions

expose.cpp

@@ -32,6 +32,7 @@ extern "C"
     {
         std::string model = inputs.model_filename;
         lora_filename = inputs.lora_filename;
+        lora_base = inputs.lora_base;
         int forceversion = inputs.forceversion;

expose.h

@@ -11,6 +11,7 @@ struct load_model_inputs
     const char * executable_path;
     const char * model_filename;
     const char * lora_filename;
+    const char * lora_base;
     const bool use_mmap;
     const bool use_mlock;
     const bool use_smartcontext;
@@ -49,5 +50,6 @@ struct generation_outputs
 extern std::string executable_path;
 extern std::string lora_filename;
+extern std::string lora_base;
 extern std::vector<std::string> generated_tokens;
 extern bool generation_finished;
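
Since ctypes matches structure fields by declaration order rather than by name, the new lora_base pointer has to occupy the same slot in the Python-side mirror of this struct (see the koboldcpp.py hunk further down). A minimal sketch with a cut-down field list (placeholder path, not the full koboldcpp struct):

import ctypes

# Illustration only: a truncated mirror of load_model_inputs from expose.h.
# Field order must match the C declaration exactly, or every field after a
# misplaced one is read from the wrong offset.
class load_model_inputs(ctypes.Structure):
    _fields_ = [("model_filename", ctypes.c_char_p),
                ("lora_filename", ctypes.c_char_p),
                ("lora_base", ctypes.c_char_p),  # new field, same slot as in expose.h
                ("use_mmap", ctypes.c_bool)]

inputs = load_model_inputs()
inputs.lora_base = "base-f16.bin".encode("UTF-8")  # c_char_p takes bytes, not str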

gpttype_adapter.cpp

@@ -31,6 +31,7 @@
 //shared
 std::string executable_path = "";
 std::string lora_filename = "";
+std::string lora_base = "";
 bool generation_finished;
 std::vector<std::string> generated_tokens;
@@ -341,9 +342,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     {
         printf("\nAttempting to apply LORA adapter: %s\n", lora_filename.c_str());
+        const char * lora_base_arg = NULL;
+        if (lora_base != "") {
+            printf("Using LORA base model: %s\n", lora_base.c_str());
+            lora_base_arg = lora_base.c_str();
+        }
         int err = llama_v2_apply_lora_from_file(llama_ctx_v2,
                                                 lora_filename.c_str(),
-                                                NULL,
+                                                lora_base_arg,
                                                 n_threads);
         if (err != 0)
         {
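
In llama.cpp's LoRA loader the third argument is an optional base-model path: passing NULL applies the adapter directly to the weights already in memory, while a separate full-precision (e.g. f16) base model is the safer choice when the loaded model is quantized, since merging LoRA deltas into quantized tensors loses precision. The empty-string-means-NULL convention above can be expressed the same way on the Python side; a hypothetical helper for illustration, not code from this commit:

# Hypothetical helper mirroring the C logic above: ctypes marshals None
# as a NULL char pointer, so an unset base model turns into NULL.
def lora_base_arg(lora_base: str):
    return lora_base.encode("UTF-8") if lora_base != "" else None

assert lora_base_arg("") is None                # no base model -> NULL
assert lora_base_arg("f16.bin") == b"f16.bin"   # placeholder path passed through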

koboldcpp.py

@@ -19,6 +19,7 @@ class load_model_inputs(ctypes.Structure):
                 ("executable_path", ctypes.c_char_p),
                 ("model_filename", ctypes.c_char_p),
                 ("lora_filename", ctypes.c_char_p),
+                ("lora_base", ctypes.c_char_p),
                 ("use_mmap", ctypes.c_bool),
                 ("use_mlock", ctypes.c_bool),
                 ("use_smartcontext", ctypes.c_bool),
@@ -146,7 +147,6 @@ def init_library():
 def load_model(model_filename):
     inputs = load_model_inputs()
     inputs.model_filename = model_filename.encode("UTF-8")
-    inputs.lora_filename = args.lora.encode("UTF-8")
     inputs.batch_size = 8
     inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
     inputs.threads = args.threads
@@ -154,8 +154,13 @@ def load_model(model_filename):
     inputs.f16_kv = True
     inputs.use_mmap = (not args.nommap)
     inputs.use_mlock = args.usemlock
-    if args.lora and args.lora!="":
+    inputs.lora_filename = "".encode("UTF-8")
+    inputs.lora_base = "".encode("UTF-8")
+    if args.lora:
+        inputs.lora_filename = args.lora[0].encode("UTF-8")
         inputs.use_mmap = False
+        if len(args.lora) > 1:
+            inputs.lora_base = args.lora[1].encode("UTF-8")
     inputs.use_smartcontext = args.smartcontext
     inputs.unban_tokens = args.unbantokens
     inputs.blasbatchsize = args.blasbatchsize
@@ -744,13 +749,20 @@ def main(args):
         time.sleep(2)
         sys.exit(2)
-    if args.lora and args.lora!="":
-        if not os.path.exists(args.lora):
-            print(f"Cannot find lora file: {args.lora}")
+    if args.lora and args.lora[0]!="":
+        if not os.path.exists(args.lora[0]):
+            print(f"Cannot find lora file: {args.lora[0]}")
             time.sleep(2)
             sys.exit(2)
         else:
-            args.lora = os.path.abspath(args.lora)
+            args.lora[0] = os.path.abspath(args.lora[0])
+        if len(args.lora) > 1:
+            if not os.path.exists(args.lora[1]):
+                print(f"Cannot find lora base: {args.lora[1]}")
+                time.sleep(2)
+                sys.exit(2)
+            else:
+                args.lora[1] = os.path.abspath(args.lora[1])

     if args.psutil_set_threads:
         import psutil
@@ -807,7 +819,7 @@ if __name__ == '__main__':
     portgroup.add_argument("port_param", help="Port to listen on (positional)", default=defaultport, nargs="?", type=int, action='store')
     parser.add_argument("--host", help="Host IP to listen on. If empty, all routable interfaces are accepted.", default="")
     parser.add_argument("--launch", help="Launches a web browser when load is completed.", action='store_true')
-    parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", default="")
+    parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", metavar=('[lora_filename]', '[lora_base]'), nargs='+')
     physical_core_limit = 1
     if os.cpu_count()!=None and os.cpu_count()>1:
         physical_core_limit = int(os.cpu_count()/2)
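
Because the retyped flag drops its old default="" and uses nargs='+', an omitted --lora now parses to None (which is why the checks above read plain "if args.lora:"); one value selects just the adapter, a second selects the base model, and anything past the second is accepted by argparse but ignored by the loader, which only reads indices 0 and 1. A standalone demo of that parsing behaviour, using placeholder file names:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lora", metavar=('[lora_filename]', '[lora_base]'), nargs='+')

print(parser.parse_args([]).lora)                                     # None
print(parser.parse_args(["--lora", "adapter.bin"]).lora)              # ['adapter.bin']
print(parser.parse_args(["--lora", "adapter.bin", "base.bin"]).lora)  # ['adapter.bin', 'base.bin']

So an invocation along the lines of "python koboldcpp.py model.bin --lora adapter.bin base-f16.bin" (all paths placeholders) would apply adapter.bin on top of model.bin while reading the layers it modifies from the full-precision base-f16.bin.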