added support for lora base
This commit is contained in:
parent
375540837e
commit
66a3f4e421
4 changed files with 30 additions and 8 deletions
|
@ -32,6 +32,7 @@ extern "C"
|
||||||
{
|
{
|
||||||
std::string model = inputs.model_filename;
|
std::string model = inputs.model_filename;
|
||||||
lora_filename = inputs.lora_filename;
|
lora_filename = inputs.lora_filename;
|
||||||
|
lora_base = inputs.lora_base;
|
||||||
|
|
||||||
int forceversion = inputs.forceversion;
|
int forceversion = inputs.forceversion;
|
||||||
|
|
||||||
|
|
2
expose.h
2
expose.h
|
@ -11,6 +11,7 @@ struct load_model_inputs
|
||||||
const char * executable_path;
|
const char * executable_path;
|
||||||
const char * model_filename;
|
const char * model_filename;
|
||||||
const char * lora_filename;
|
const char * lora_filename;
|
||||||
|
const char * lora_base;
|
||||||
const bool use_mmap;
|
const bool use_mmap;
|
||||||
const bool use_mlock;
|
const bool use_mlock;
|
||||||
const bool use_smartcontext;
|
const bool use_smartcontext;
|
||||||
|
@ -49,5 +50,6 @@ struct generation_outputs
|
||||||
|
|
||||||
extern std::string executable_path;
|
extern std::string executable_path;
|
||||||
extern std::string lora_filename;
|
extern std::string lora_filename;
|
||||||
|
extern std::string lora_base;
|
||||||
extern std::vector<std::string> generated_tokens;
|
extern std::vector<std::string> generated_tokens;
|
||||||
extern bool generation_finished;
|
extern bool generation_finished;
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
//shared
|
//shared
|
||||||
std::string executable_path = "";
|
std::string executable_path = "";
|
||||||
std::string lora_filename = "";
|
std::string lora_filename = "";
|
||||||
|
std::string lora_base = "";
|
||||||
bool generation_finished;
|
bool generation_finished;
|
||||||
std::vector<std::string> generated_tokens;
|
std::vector<std::string> generated_tokens;
|
||||||
|
|
||||||
|
@ -341,9 +342,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
||||||
{
|
{
|
||||||
printf("\nAttempting to apply LORA adapter: %s\n", lora_filename.c_str());
|
printf("\nAttempting to apply LORA adapter: %s\n", lora_filename.c_str());
|
||||||
|
|
||||||
|
const char * lora_base_arg = NULL;
|
||||||
|
if (lora_base != "") {
|
||||||
|
printf("Using LORA base model: %s\n", lora_base.c_str());
|
||||||
|
lora_base_arg = lora_base.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
int err = llama_v2_apply_lora_from_file(llama_ctx_v2,
|
int err = llama_v2_apply_lora_from_file(llama_ctx_v2,
|
||||||
lora_filename.c_str(),
|
lora_filename.c_str(),
|
||||||
NULL,
|
lora_base_arg,
|
||||||
n_threads);
|
n_threads);
|
||||||
if (err != 0)
|
if (err != 0)
|
||||||
{
|
{
|
||||||
|
|
26
koboldcpp.py
26
koboldcpp.py
|
@ -19,6 +19,7 @@ class load_model_inputs(ctypes.Structure):
|
||||||
("executable_path", ctypes.c_char_p),
|
("executable_path", ctypes.c_char_p),
|
||||||
("model_filename", ctypes.c_char_p),
|
("model_filename", ctypes.c_char_p),
|
||||||
("lora_filename", ctypes.c_char_p),
|
("lora_filename", ctypes.c_char_p),
|
||||||
|
("lora_base", ctypes.c_char_p),
|
||||||
("use_mmap", ctypes.c_bool),
|
("use_mmap", ctypes.c_bool),
|
||||||
("use_mlock", ctypes.c_bool),
|
("use_mlock", ctypes.c_bool),
|
||||||
("use_smartcontext", ctypes.c_bool),
|
("use_smartcontext", ctypes.c_bool),
|
||||||
|
@ -146,7 +147,6 @@ def init_library():
|
||||||
def load_model(model_filename):
|
def load_model(model_filename):
|
||||||
inputs = load_model_inputs()
|
inputs = load_model_inputs()
|
||||||
inputs.model_filename = model_filename.encode("UTF-8")
|
inputs.model_filename = model_filename.encode("UTF-8")
|
||||||
inputs.lora_filename = args.lora.encode("UTF-8")
|
|
||||||
inputs.batch_size = 8
|
inputs.batch_size = 8
|
||||||
inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
|
inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
|
||||||
inputs.threads = args.threads
|
inputs.threads = args.threads
|
||||||
|
@ -154,8 +154,13 @@ def load_model(model_filename):
|
||||||
inputs.f16_kv = True
|
inputs.f16_kv = True
|
||||||
inputs.use_mmap = (not args.nommap)
|
inputs.use_mmap = (not args.nommap)
|
||||||
inputs.use_mlock = args.usemlock
|
inputs.use_mlock = args.usemlock
|
||||||
if args.lora and args.lora!="":
|
inputs.lora_filename = ""
|
||||||
|
inputs.lora_base = ""
|
||||||
|
if args.lora:
|
||||||
|
inputs.lora_filename = args.lora[0].encode("UTF-8")
|
||||||
inputs.use_mmap = False
|
inputs.use_mmap = False
|
||||||
|
if len(args.lora) > 1:
|
||||||
|
inputs.lora_base = args.lora[1].encode("UTF-8")
|
||||||
inputs.use_smartcontext = args.smartcontext
|
inputs.use_smartcontext = args.smartcontext
|
||||||
inputs.unban_tokens = args.unbantokens
|
inputs.unban_tokens = args.unbantokens
|
||||||
inputs.blasbatchsize = args.blasbatchsize
|
inputs.blasbatchsize = args.blasbatchsize
|
||||||
|
@ -744,13 +749,20 @@ def main(args):
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
sys.exit(2)
|
sys.exit(2)
|
||||||
|
|
||||||
if args.lora and args.lora!="":
|
if args.lora and args.lora[0]!="":
|
||||||
if not os.path.exists(args.lora):
|
if not os.path.exists(args.lora[0]):
|
||||||
print(f"Cannot find lora file: {args.lora}")
|
print(f"Cannot find lora file: {args.lora[0]}")
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
sys.exit(2)
|
sys.exit(2)
|
||||||
else:
|
else:
|
||||||
args.lora = os.path.abspath(args.lora)
|
args.lora[0] = os.path.abspath(args.lora[0])
|
||||||
|
if len(args.lora) > 1:
|
||||||
|
if not os.path.exists(args.lora[1]):
|
||||||
|
print(f"Cannot find lora base: {args.lora[1]}")
|
||||||
|
time.sleep(2)
|
||||||
|
sys.exit(2)
|
||||||
|
else:
|
||||||
|
args.lora[1] = os.path.abspath(args.lora[1])
|
||||||
|
|
||||||
if args.psutil_set_threads:
|
if args.psutil_set_threads:
|
||||||
import psutil
|
import psutil
|
||||||
|
@ -807,7 +819,7 @@ if __name__ == '__main__':
|
||||||
portgroup.add_argument("port_param", help="Port to listen on (positional)", default=defaultport, nargs="?", type=int, action='store')
|
portgroup.add_argument("port_param", help="Port to listen on (positional)", default=defaultport, nargs="?", type=int, action='store')
|
||||||
parser.add_argument("--host", help="Host IP to listen on. If empty, all routable interfaces are accepted.", default="")
|
parser.add_argument("--host", help="Host IP to listen on. If empty, all routable interfaces are accepted.", default="")
|
||||||
parser.add_argument("--launch", help="Launches a web browser when load is completed.", action='store_true')
|
parser.add_argument("--launch", help="Launches a web browser when load is completed.", action='store_true')
|
||||||
parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", default="")
|
parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", metavar=('[lora_filename]', '[lora_base]'), nargs='+')
|
||||||
physical_core_limit = 1
|
physical_core_limit = 1
|
||||||
if os.cpu_count()!=None and os.cpu_count()>1:
|
if os.cpu_count()!=None and os.cpu_count()>1:
|
||||||
physical_core_limit = int(os.cpu_count()/2)
|
physical_core_limit = int(os.cpu_count()/2)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue