added support for lora base
This commit is contained in:
parent
375540837e
commit
66a3f4e421
4 changed files with 30 additions and 8 deletions
|
@ -32,6 +32,7 @@ extern "C"
|
|||
{
|
||||
std::string model = inputs.model_filename;
|
||||
lora_filename = inputs.lora_filename;
|
||||
lora_base = inputs.lora_base;
|
||||
|
||||
int forceversion = inputs.forceversion;
|
||||
|
||||
|
|
2
expose.h
2
expose.h
|
@ -11,6 +11,7 @@ struct load_model_inputs
|
|||
const char * executable_path;
|
||||
const char * model_filename;
|
||||
const char * lora_filename;
|
||||
const char * lora_base;
|
||||
const bool use_mmap;
|
||||
const bool use_mlock;
|
||||
const bool use_smartcontext;
|
||||
|
@ -49,5 +50,6 @@ struct generation_outputs
|
|||
|
||||
extern std::string executable_path;
|
||||
extern std::string lora_filename;
|
||||
extern std::string lora_base;
|
||||
extern std::vector<std::string> generated_tokens;
|
||||
extern bool generation_finished;
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
//shared
|
||||
std::string executable_path = "";
|
||||
std::string lora_filename = "";
|
||||
std::string lora_base = "";
|
||||
bool generation_finished;
|
||||
std::vector<std::string> generated_tokens;
|
||||
|
||||
|
@ -341,9 +342,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
{
|
||||
printf("\nAttempting to apply LORA adapter: %s\n", lora_filename.c_str());
|
||||
|
||||
const char * lora_base_arg = NULL;
|
||||
if (lora_base != "") {
|
||||
printf("Using LORA base model: %s\n", lora_base.c_str());
|
||||
lora_base_arg = lora_base.c_str();
|
||||
}
|
||||
|
||||
int err = llama_v2_apply_lora_from_file(llama_ctx_v2,
|
||||
lora_filename.c_str(),
|
||||
NULL,
|
||||
lora_base_arg,
|
||||
n_threads);
|
||||
if (err != 0)
|
||||
{
|
||||
|
|
26
koboldcpp.py
26
koboldcpp.py
|
@ -19,6 +19,7 @@ class load_model_inputs(ctypes.Structure):
|
|||
("executable_path", ctypes.c_char_p),
|
||||
("model_filename", ctypes.c_char_p),
|
||||
("lora_filename", ctypes.c_char_p),
|
||||
("lora_base", ctypes.c_char_p),
|
||||
("use_mmap", ctypes.c_bool),
|
||||
("use_mlock", ctypes.c_bool),
|
||||
("use_smartcontext", ctypes.c_bool),
|
||||
|
@ -146,7 +147,6 @@ def init_library():
|
|||
def load_model(model_filename):
|
||||
inputs = load_model_inputs()
|
||||
inputs.model_filename = model_filename.encode("UTF-8")
|
||||
inputs.lora_filename = args.lora.encode("UTF-8")
|
||||
inputs.batch_size = 8
|
||||
inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
|
||||
inputs.threads = args.threads
|
||||
|
@ -154,8 +154,13 @@ def load_model(model_filename):
|
|||
inputs.f16_kv = True
|
||||
inputs.use_mmap = (not args.nommap)
|
||||
inputs.use_mlock = args.usemlock
|
||||
if args.lora and args.lora!="":
|
||||
inputs.lora_filename = ""
|
||||
inputs.lora_base = ""
|
||||
if args.lora:
|
||||
inputs.lora_filename = args.lora[0].encode("UTF-8")
|
||||
inputs.use_mmap = False
|
||||
if len(args.lora) > 1:
|
||||
inputs.lora_base = args.lora[1].encode("UTF-8")
|
||||
inputs.use_smartcontext = args.smartcontext
|
||||
inputs.unban_tokens = args.unbantokens
|
||||
inputs.blasbatchsize = args.blasbatchsize
|
||||
|
@ -744,13 +749,20 @@ def main(args):
|
|||
time.sleep(2)
|
||||
sys.exit(2)
|
||||
|
||||
if args.lora and args.lora!="":
|
||||
if not os.path.exists(args.lora):
|
||||
print(f"Cannot find lora file: {args.lora}")
|
||||
if args.lora and args.lora[0]!="":
|
||||
if not os.path.exists(args.lora[0]):
|
||||
print(f"Cannot find lora file: {args.lora[0]}")
|
||||
time.sleep(2)
|
||||
sys.exit(2)
|
||||
else:
|
||||
args.lora = os.path.abspath(args.lora)
|
||||
args.lora[0] = os.path.abspath(args.lora[0])
|
||||
if len(args.lora) > 1:
|
||||
if not os.path.exists(args.lora[1]):
|
||||
print(f"Cannot find lora base: {args.lora[1]}")
|
||||
time.sleep(2)
|
||||
sys.exit(2)
|
||||
else:
|
||||
args.lora[1] = os.path.abspath(args.lora[1])
|
||||
|
||||
if args.psutil_set_threads:
|
||||
import psutil
|
||||
|
@ -807,7 +819,7 @@ if __name__ == '__main__':
|
|||
portgroup.add_argument("port_param", help="Port to listen on (positional)", default=defaultport, nargs="?", type=int, action='store')
|
||||
parser.add_argument("--host", help="Host IP to listen on. If empty, all routable interfaces are accepted.", default="")
|
||||
parser.add_argument("--launch", help="Launches a web browser when load is completed.", action='store_true')
|
||||
parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", default="")
|
||||
parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", metavar=('[lora_filename]', '[lora_base]'), nargs='+')
|
||||
physical_core_limit = 1
|
||||
if os.cpu_count()!=None and os.cpu_count()>1:
|
||||
physical_core_limit = int(os.cpu_count()/2)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue