diff --git a/expose.h b/expose.h
index eb116a5b4..40a4aff94 100644
--- a/expose.h
+++ b/expose.h
@@ -27,7 +27,6 @@ struct load_model_inputs
     const int threads;
     const int blasthreads;
     const int max_context_length;
-    const int batch_size;
     const bool low_vram;
     const bool use_mmq;
     const char * executable_path;
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index befe360f5..90ca74fab 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -82,6 +82,9 @@ static int n_batch = 8;
 static bool useSmartContext = false;
 static bool useContextShift = false;
 static int blasbatchsize = 512;
+static int dontblasbatchsize = 16;
+static int normalbatchsize = 32;
+static int smallbatchsize = 8;
 static int debugmode = 0; //-1 = hide all, 0 = normal, 1 = showall
 static std::string modelname;
 static std::vector last_n_tokens;
@@ -671,7 +674,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     file_format = in_file_format;
     n_threads = params.n_threads = inputs.threads;
     n_blasthreads = params.n_threads_batch = inputs.blasthreads;
-    n_batch = params.n_batch = inputs.batch_size;
+    bool isGguf = (file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON);
+
+    n_batch = params.n_batch = (isGguf?normalbatchsize:smallbatchsize);
     modelname = params.model = inputs.model_filename;
     useSmartContext = inputs.use_smartcontext;
     useContextShift = inputs.use_contextshift;
@@ -679,7 +684,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     blasbatchsize = inputs.blasbatchsize;
     if(blasbatchsize<=0)
     {
-        blasbatchsize = 8;
+        blasbatchsize = (isGguf?dontblasbatchsize:smallbatchsize);
     }

     auto clamped_max_context_length = inputs.max_context_length;
diff --git a/koboldcpp.py b/koboldcpp.py
index 452c5dc89..eafa86e72 100755
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -23,7 +23,6 @@ class load_model_inputs(ctypes.Structure):
     _fields_ = [("threads", ctypes.c_int),
                 ("blasthreads", ctypes.c_int),
                 ("max_context_length", ctypes.c_int),
-                ("batch_size", ctypes.c_int),
                 ("low_vram", ctypes.c_bool),
                 ("use_mmq", ctypes.c_bool),
                 ("executable_path", ctypes.c_char_p),
@@ -229,7 +228,6 @@ def load_model(model_filename):
     global args
     inputs = load_model_inputs()
     inputs.model_filename = model_filename.encode("UTF-8")
-    inputs.batch_size = 8
     inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
     inputs.threads = args.threads
     inputs.low_vram = (True if (args.usecublas and "lowvram" in args.usecublas) else False)
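
Net effect of the patch: batch_size is dropped from load_model_inputs (and from its ctypes mirror in koboldcpp.py), and the backend now picks n_batch itself, 32 for GGUF models and 8 for legacy formats, while a non-positive user blasbatchsize falls back to 16 for GGUF and 8 otherwise. The following is a minimal, self-contained C++ sketch of that selection logic only; the cut-down FileFormat enum and the BatchSizes/pick_batch_sizes names are stand-ins for illustration, not part of the real gpttype_adapter.cpp.

// Sketch of the batch-size selection introduced by this patch, with
// simplified stand-in types instead of koboldcpp's real definitions.
#include <cstdio>

enum class FileFormat { GGJT_3, GGUF_LLAMA, GGUF_FALCON }; // reduced stand-in

struct BatchSizes { int n_batch; int blasbatchsize; };     // hypothetical helper

// Mirrors the new logic in gpttype_load_model: both the per-token batch and
// the BLAS-batch fallback are chosen by the backend, not passed in by the UI.
static BatchSizes pick_batch_sizes(FileFormat fmt, int requested_blasbatchsize)
{
    const bool isGguf = (fmt == FileFormat::GGUF_LLAMA || fmt == FileFormat::GGUF_FALCON);
    const int normalbatchsize   = 32; // default n_batch for GGUF models
    const int smallbatchsize    = 8;  // default n_batch for legacy formats
    const int dontblasbatchsize = 16; // GGUF fallback when blasbatchsize <= 0

    BatchSizes out;
    out.n_batch = isGguf ? normalbatchsize : smallbatchsize;
    out.blasbatchsize = (requested_blasbatchsize <= 0)
                            ? (isGguf ? dontblasbatchsize : smallbatchsize)
                            : requested_blasbatchsize;
    return out;
}

int main()
{
    BatchSizes s = pick_batch_sizes(FileFormat::GGUF_LLAMA, -1);
    std::printf("n_batch=%d blasbatchsize=%d\n", s.n_batch, s.blasbatchsize); // n_batch=32 blasbatchsize=16
}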