only llama can use batch sizes above 256 to prevent unacceptably high memory usage

Concedo 2023-04-23 15:57:06 +08:00
parent 432cc91649
commit 9129e937f9

@@ -358,8 +358,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     int original_threads = params.n_threads;
     if (blasmode)
     {
-        //for gpttype, GPT2 crashes above 256.
-        int bbs = blasbatchsize; //(blasbatchsize>256?256:blasbatchsize);
+        //for non llama, limit to 256
+        int bbs = blasbatchsize;
+        if (file_format != FileFormat::GGML && file_format != FileFormat::GGHF && file_format != FileFormat::GGJT)
+        {
+            bbs = (blasbatchsize > 256 ? 256 : blasbatchsize);
+        }
         params.n_batch = bbs; //received reports of 1024 and above crashing on some models
         params.n_threads = 1;
     }
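
For clarity, below is a minimal, self-contained sketch of the new clamping logic as a pure function. The format names GGML, GGHF, and GGJT and the 256 cap come straight from the diff; the enum definition, the GPT2 enumerator, and the effective_blas_batch_size helper are illustrative assumptions, not koboldcpp's actual declarations.

#include <cstdio>

// File formats, mirroring the names used in the diff. GGML/GGHF/GGJT are the
// llama family; GPT2 is a placeholder standing in for any non-llama format
// (an assumption for this sketch, not the real enum definition).
enum class FileFormat { GGML, GGHF, GGJT, GPT2 };

// Effective BLAS batch size: llama formats keep the requested size, while
// every other format is capped at 256, since larger batches were reported
// to crash some non-llama models.
int effective_blas_batch_size(FileFormat file_format, int blasbatchsize)
{
    int bbs = blasbatchsize;
    if (file_format != FileFormat::GGML && file_format != FileFormat::GGHF &&
        file_format != FileFormat::GGJT)
    {
        bbs = (blasbatchsize > 256 ? 256 : blasbatchsize);
    }
    return bbs;
}

int main()
{
    std::printf("GGJT (llama), requested 1024 -> %d\n",
                effective_blas_batch_size(FileFormat::GGJT, 1024)); // stays 1024
    std::printf("GPT2 (non-llama), requested 1024 -> %d\n",
                effective_blas_batch_size(FileFormat::GPT2, 1024)); // capped to 256
    return 0;
}

Factoring the clamp into a function like this would also make the behavior easy to unit-test, although the commit itself applies it inline where params.n_batch is set.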