Fix clblast device selection on Linux
This commit is contained in:
parent
c3b810868d
commit
8fbfc80e03
1 changed files with 10 additions and 8 deletions
18
expose.cpp
18
expose.cpp
|
@ -23,6 +23,8 @@
|
|||
extern "C"
|
||||
{
|
||||
|
||||
std::string platformenv, deviceenv;
|
||||
|
||||
//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
|
||||
static FileFormat file_format = FileFormat::BADFORMAT;
|
||||
|
||||
|
@ -33,15 +35,15 @@ extern "C"
|
|||
|
||||
//first digit is whether configured, second is platform, third is devices
|
||||
int parseinfo = inputs.clblast_info;
|
||||
|
||||
|
||||
std::string usingclblast = "KCPP_CLBLAST_CONFIGURED="+std::to_string(parseinfo>0?1:0);
|
||||
putenv((char*)usingclblast.c_str());
|
||||
|
||||
parseinfo = parseinfo%100; //keep last 2 digits
|
||||
parseinfo = parseinfo%100; //keep last 2 digits
|
||||
int platform = parseinfo/10;
|
||||
int devices = parseinfo%10;
|
||||
std::string platformenv = "KCPP_CLBLAST_PLATFORM="+std::to_string(platform);
|
||||
std::string deviceenv = "KCPP_CLBLAST_DEVICES="+std::to_string(devices);
|
||||
platformenv = "KCPP_CLBLAST_PLATFORM="+std::to_string(platform);
|
||||
deviceenv = "KCPP_CLBLAST_DEVICES="+std::to_string(devices);
|
||||
putenv((char*)platformenv.c_str());
|
||||
putenv((char*)deviceenv.c_str());
|
||||
|
||||
|
@ -61,7 +63,7 @@ extern "C"
|
|||
printf("\n---\nRetrying as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format);
|
||||
lr = gpttype_load_model(inputs, file_format);
|
||||
}
|
||||
|
||||
|
||||
if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD)
|
||||
{
|
||||
return false;
|
||||
|
@ -92,14 +94,14 @@ extern "C"
|
|||
}
|
||||
else
|
||||
{
|
||||
printf("\n---\nIdentified as LLAMA model: (ver %d)\nAttempting to Load...\n---\n", file_format);
|
||||
printf("\n---\nIdentified as LLAMA model: (ver %d)\nAttempting to Load...\n---\n", file_format);
|
||||
return llama_load_model(inputs, file_format);
|
||||
}
|
||||
}
|
||||
|
||||
generation_outputs generate(const generation_inputs inputs, generation_outputs &output)
|
||||
{
|
||||
if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2 || file_format==FileFormat::GPTJ_3
|
||||
if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2 || file_format==FileFormat::GPTJ_3
|
||||
|| file_format==FileFormat::GPT2_1 || file_format==FileFormat::GPT2_2 )
|
||||
{
|
||||
return gpttype_generate(inputs, output);
|
||||
|
@ -107,6 +109,6 @@ extern "C"
|
|||
else
|
||||
{
|
||||
return llama_generate(inputs, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue