Fix clblast device selection on Linux

0cc4m 2023-04-15 12:02:36 +02:00
parent c3b810868d
commit 8fbfc80e03


@@ -23,6 +23,8 @@
 extern "C"
 {
+    std::string platformenv, deviceenv;
     //return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
     static FileFormat file_format = FileFormat::BADFORMAT;
@@ -33,15 +35,15 @@ extern "C"
         //first digit is whether configured, second is platform, third is devices
         int parseinfo = inputs.clblast_info;
         std::string usingclblast = "KCPP_CLBLAST_CONFIGURED="+std::to_string(parseinfo>0?1:0);
         putenv((char*)usingclblast.c_str());
-        parseinfo = parseinfo%100; //keep last 2 digits
+        parseinfo = parseinfo%100; //keep last 2 digits
         int platform = parseinfo/10;
         int devices = parseinfo%10;
-        std::string platformenv = "KCPP_CLBLAST_PLATFORM="+std::to_string(platform);
-        std::string deviceenv = "KCPP_CLBLAST_DEVICES="+std::to_string(devices);
+        platformenv = "KCPP_CLBLAST_PLATFORM="+std::to_string(platform);
+        deviceenv = "KCPP_CLBLAST_DEVICES="+std::to_string(devices);
         putenv((char*)platformenv.c_str());
         putenv((char*)deviceenv.c_str());
@@ -61,7 +63,7 @@ extern "C"
                 printf("\n---\nRetrying as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format);
                 lr = gpttype_load_model(inputs, file_format);
             }
             if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD)
             {
                 return false;
@@ -92,14 +94,14 @@ extern "C"
         }
         else
         {
-            printf("\n---\nIdentified as LLAMA model: (ver %d)\nAttempting to Load...\n---\n", file_format);
+            printf("\n---\nIdentified as LLAMA model: (ver %d)\nAttempting to Load...\n---\n", file_format);
             return llama_load_model(inputs, file_format);
         }
     }
     generation_outputs generate(const generation_inputs inputs, generation_outputs &output)
     {
-        if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2 || file_format==FileFormat::GPTJ_3
+        if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2 || file_format==FileFormat::GPTJ_3
             || file_format==FileFormat::GPT2_1 || file_format==FileFormat::GPT2_2 )
         {
             return gpttype_generate(inputs, output);
@@ -107,6 +109,6 @@ extern "C"
         else
         {
             return llama_generate(inputs, output);
         }
     }
 }
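
Why the change matters: on Linux, glibc's putenv() does not copy the string it is given, it stores the pointer itself in the environment, so passing c_str() of a std::string that is local to load_model leaves the environment pointing at freed memory once the function returns (the Windows CRT copies the value, which is presumably why this only misbehaved on Linux). Moving platformenv and deviceenv to file scope keeps the buffers alive. The sketch below is a minimal, self-contained illustration of that idea, not the repository's code; apply_clblast_env, configuredenv, and the value 110 are made up for the example.

// Minimal sketch of the putenv lifetime issue (hypothetical file, not part of the repo).
// build (Linux): g++ putenv_sketch.cpp -o putenv_sketch
#include <cstdio>
#include <cstdlib>
#include <string>

// File-scope strings, mirroring the platformenv/deviceenv this commit introduces.
// glibc's putenv() keeps the pointer it is handed instead of copying the string,
// so the backing storage must outlive the function that calls putenv().
static std::string configuredenv, platformenv, deviceenv;

// Decode the packed value described in the diff comment: first digit = configured,
// second digit = platform index, third digit = device index.
void apply_clblast_env(int clblast_info)
{
    configuredenv = "KCPP_CLBLAST_CONFIGURED=" + std::to_string(clblast_info > 0 ? 1 : 0);
    int parseinfo = clblast_info % 100; // keep last 2 digits
    platformenv   = "KCPP_CLBLAST_PLATFORM=" + std::to_string(parseinfo / 10);
    deviceenv     = "KCPP_CLBLAST_DEVICES=" + std::to_string(parseinfo % 10);

    // Had these been locals (as before this commit), their c_str() pointers would
    // dangle as soon as this function returned, and any later getenv() lookup
    // could read freed memory.
    putenv((char*)configuredenv.c_str());
    putenv((char*)platformenv.c_str());
    putenv((char*)deviceenv.c_str());
}

int main()
{
    apply_clblast_env(110); // configured, platform 1, device 0
    printf("%s\n", getenv("KCPP_CLBLAST_PLATFORM")); // prints 1
    printf("%s\n", getenv("KCPP_CLBLAST_DEVICES"));  // prints 0
    return 0;
}

On POSIX, setenv() copies its arguments and would sidestep the lifetime question entirely, but it is not available in the MSVC runtime, so file-scope strings plus putenv() remain portable across both platforms.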