added optional force versioning flag
This commit is contained in:
parent
39f3d1cf48
commit
2edbcebe27
3 changed files with 17 additions and 5 deletions
13
expose.cpp
13
expose.cpp
|
@ -35,7 +35,18 @@ extern "C"
|
|||
{
|
||||
std::string model = inputs.model_filename;
|
||||
lora_filename = inputs.lora_filename;
|
||||
file_format = check_file_format(model.c_str());
|
||||
|
||||
int forceversion = inputs.forceversion;
|
||||
|
||||
if(forceversion==0)
|
||||
{
|
||||
file_format = check_file_format(model.c_str());
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("\nWARNING: FILE FORMAT FORCED TO VER %d\nIf incorrect, loading may fail or crash.\n",forceversion);
|
||||
file_format = (FileFormat)forceversion;
|
||||
}
|
||||
|
||||
//clblast_info packs three decimal digits: hundreds digit is whether configured, tens digit is platform, ones digit is device
|
||||
int parseinfo = inputs.clblast_info;
|
||||
|
|
1
expose.h
1
expose.h
|
@ -18,6 +18,7 @@ struct load_model_inputs
|
|||
const int clblast_info = 0;
|
||||
const int blasbatchsize = 512;
|
||||
const bool debugmode;
|
||||
const int forceversion = 0;
|
||||
};
|
||||
struct generation_inputs
|
||||
{
|
||||
|
|
|
@ -24,7 +24,8 @@ class load_model_inputs(ctypes.Structure):
|
|||
("unban_tokens", ctypes.c_bool),
|
||||
("clblast_info", ctypes.c_int),
|
||||
("blasbatchsize", ctypes.c_int),
|
||||
("debugmode", ctypes.c_bool)]
|
||||
("debugmode", ctypes.c_bool),
|
||||
("forceversion", ctypes.c_int)]
|
||||
|
||||
class generation_inputs(ctypes.Structure):
|
||||
_fields_ = [("seed", ctypes.c_int),
|
||||
|
@ -143,6 +144,7 @@ def load_model(model_filename):
|
|||
inputs.use_smartcontext = args.smartcontext
|
||||
inputs.unban_tokens = args.unbantokens
|
||||
inputs.blasbatchsize = args.blasbatchsize
|
||||
inputs.forceversion = args.forceversion
|
||||
clblastids = 0
|
||||
if args.useclblast:
|
||||
clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1])
|
||||
|
@ -601,9 +603,6 @@ if __name__ == '__main__':
|
|||
parser.add_argument("--host", help="Host IP to listen on. If empty, all routable interfaces are accepted.", default="")
|
||||
parser.add_argument("--launch", help="Launches a web browser when load is completed.", action='store_true')
|
||||
parser.add_argument("--lora", help="LLAMA models only, applies a lora file on top of model. Experimental.", default="")
|
||||
|
||||
#os.environ["OMP_NUM_THREADS"] = '12'
|
||||
# psutil.cpu_count(logical=False)
|
||||
physical_core_limit = 1
|
||||
if os.cpu_count()!=None and os.cpu_count()>1:
|
||||
physical_core_limit = int(os.cpu_count()/2)
|
||||
|
@ -616,6 +615,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument("--stream", help="Uses pseudo streaming when generating tokens. Only for the Kobold Lite UI.", action='store_true')
|
||||
parser.add_argument("--smartcontext", help="Reserving a portion of context to try processing less frequently.", action='store_true')
|
||||
parser.add_argument("--unbantokens", help="Normally, KoboldAI prevents certain tokens such as EOS and Square Brackets. This flag unbans them.", action='store_true')
|
||||
parser.add_argument("--forceversion", help="If the model file format detection fails (e.g. rogue modified model) you can set this to override the detected format (enter desired version, e.g. 401 for GPTNeoX-Type2).", type=int, default=0)
|
||||
parser.add_argument("--nommap", help="If set, do not use mmap to load newer models", action='store_true')
|
||||
parser.add_argument("--usemlock", help="For Apple Systems. Force system to keep model in RAM rather than swapping or compressing", action='store_true')
|
||||
parser.add_argument("--noavx2", help="Do not use AVX2 instructions, a slower compatibility mode for older devices. Does not work with --clblast.", action='store_true')
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue