batch size improvements
This commit is contained in:
parent 230a638512
commit 77463e0e9c

3 changed files with 7 additions and 5 deletions
expose.h

@@ -27,7 +27,6 @@ struct load_model_inputs
     const int threads;
     const int blasthreads;
     const int max_context_length;
-    const int batch_size;
     const bool low_vram;
     const bool use_mmq;
     const char * executable_path;
@@ -82,6 +82,9 @@ static int n_batch = 8;
 static bool useSmartContext = false;
 static bool useContextShift = false;
 static int blasbatchsize = 512;
+static int dontblasbatchsize = 16;
+static int normalbatchsize = 32;
+static int smallbatchsize = 8;
 static int debugmode = 0; //-1 = hide all, 0 = normal, 1 = showall
 static std::string modelname;
 static std::vector<gpt_vocab::id> last_n_tokens;
@@ -671,7 +674,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     file_format = in_file_format;
     n_threads = params.n_threads = inputs.threads;
     n_blasthreads = params.n_threads_batch = inputs.blasthreads;
-    n_batch = params.n_batch = inputs.batch_size;
+    bool isGguf = (file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON);
+
+    n_batch = params.n_batch = (isGguf?normalbatchsize:smallbatchsize);
     modelname = params.model = inputs.model_filename;
     useSmartContext = inputs.use_smartcontext;
     useContextShift = inputs.use_contextshift;
@@ -679,7 +684,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     blasbatchsize = inputs.blasbatchsize;
     if(blasbatchsize<=0)
     {
-        blasbatchsize = 8;
+        blasbatchsize = (isGguf?dontblasbatchsize:smallbatchsize);
     }

     auto clamped_max_context_length = inputs.max_context_length;
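The net effect of the two hunks above: n_batch and the blasbatchsize fallback are now derived from whether the loaded model is a GGUF file, instead of being taken from the caller-supplied batch_size. Below is a minimal, self-contained sketch of that selection logic; the FileFormat enum and the main() driver are illustrative stand-ins, and only the two GGUF variant names and the three size constants come from the diff itself.

#include <cstdio>

// Stand-in for the project's FileFormat enum: only the two GGUF variants
// named in the hunks above are real; OTHER is a placeholder for older formats.
enum class FileFormat { GGUF_LLAMA, GGUF_FALCON, OTHER };

// Fixed batch sizes introduced by this commit.
static int dontblasbatchsize = 16;
static int normalbatchsize = 32;
static int smallbatchsize = 8;

int main()
{
    FileFormat file_format = FileFormat::GGUF_LLAMA; // example: a GGUF llama model
    int requested_blasbatchsize = 0;                 // example: caller left it non-positive

    bool isGguf = (file_format == FileFormat::GGUF_LLAMA || file_format == FileFormat::GGUF_FALCON);

    // n_batch no longer comes from load_model_inputs.batch_size;
    // it is chosen from the detected file format instead.
    int n_batch = (isGguf ? normalbatchsize : smallbatchsize);

    // The fallback for a non-positive blasbatchsize is also format-dependent now.
    int blasbatchsize = requested_blasbatchsize;
    if (blasbatchsize <= 0)
    {
        blasbatchsize = (isGguf ? dontblasbatchsize : smallbatchsize);
    }

    std::printf("n_batch=%d blasbatchsize=%d\n", n_batch, blasbatchsize); // n_batch=32 blasbatchsize=16
    return 0;
}

With a GGUF model this picks normalbatchsize (32) for n_batch and falls back to dontblasbatchsize (16) when blasbatchsize is non-positive; any other format falls back to smallbatchsize (8) in both places.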
@@ -23,7 +23,6 @@ class load_model_inputs(ctypes.Structure):
     _fields_ = [("threads", ctypes.c_int),
                 ("blasthreads", ctypes.c_int),
                 ("max_context_length", ctypes.c_int),
-                ("batch_size", ctypes.c_int),
                 ("low_vram", ctypes.c_bool),
                 ("use_mmq", ctypes.c_bool),
                 ("executable_path", ctypes.c_char_p),
@@ -229,7 +228,6 @@ def load_model(model_filename):
     global args
     inputs = load_model_inputs()
     inputs.model_filename = model_filename.encode("UTF-8")
-    inputs.batch_size = 8
     inputs.max_context_length = maxctx #initial value to use for ctx, can be overwritten
     inputs.threads = args.threads
     inputs.low_vram = (True if (args.usecublas and "lowvram" in args.usecublas) else False)
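On the Python side, the same field is removed from the ctypes load_model_inputs structure and load_model() no longer assigns it, keeping the Python struct layout in step with the trimmed C struct in expose.h now that the backend chooses the batch size itself.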