Fix clblast device selection on Linux
This commit is contained in:
parent
c3b810868d
commit
8fbfc80e03
1 changed file with 10 additions and 8 deletions
18
expose.cpp
18
expose.cpp
|
@ -23,6 +23,8 @@
|
||||||
extern "C"
|
extern "C"
|
||||||
{
|
{
|
||||||
|
|
||||||
|
std::string platformenv, deviceenv;
|
||||||
|
|
||||||
//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
|
//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
|
||||||
static FileFormat file_format = FileFormat::BADFORMAT;
|
static FileFormat file_format = FileFormat::BADFORMAT;
|
||||||
|
|
||||||
|
@ -33,15 +35,15 @@ extern "C"
|
||||||
|
|
||||||
//first digit is whether configured, second is platform, third is devices
|
//first digit is whether configured, second is platform, third is devices
|
||||||
int parseinfo = inputs.clblast_info;
|
int parseinfo = inputs.clblast_info;
|
||||||
|
|
||||||
std::string usingclblast = "KCPP_CLBLAST_CONFIGURED="+std::to_string(parseinfo>0?1:0);
|
std::string usingclblast = "KCPP_CLBLAST_CONFIGURED="+std::to_string(parseinfo>0?1:0);
|
||||||
putenv((char*)usingclblast.c_str());
|
putenv((char*)usingclblast.c_str());
|
||||||
|
|
||||||
parseinfo = parseinfo%100; //keep last 2 digits
|
parseinfo = parseinfo%100; //keep last 2 digits
|
||||||
int platform = parseinfo/10;
|
int platform = parseinfo/10;
|
||||||
int devices = parseinfo%10;
|
int devices = parseinfo%10;
|
||||||
std::string platformenv = "KCPP_CLBLAST_PLATFORM="+std::to_string(platform);
|
platformenv = "KCPP_CLBLAST_PLATFORM="+std::to_string(platform);
|
||||||
std::string deviceenv = "KCPP_CLBLAST_DEVICES="+std::to_string(devices);
|
deviceenv = "KCPP_CLBLAST_DEVICES="+std::to_string(devices);
|
||||||
putenv((char*)platformenv.c_str());
|
putenv((char*)platformenv.c_str());
|
||||||
putenv((char*)deviceenv.c_str());
|
putenv((char*)deviceenv.c_str());
|
||||||
|
|
||||||
|
@ -61,7 +63,7 @@ extern "C"
|
||||||
printf("\n---\nRetrying as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format);
|
printf("\n---\nRetrying as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format);
|
||||||
lr = gpttype_load_model(inputs, file_format);
|
lr = gpttype_load_model(inputs, file_format);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD)
|
if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD)
|
||||||
{
|
{
|
||||||
return false;
|
return false;
|
||||||
|
@ -92,14 +94,14 @@ extern "C"
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
printf("\n---\nIdentified as LLAMA model: (ver %d)\nAttempting to Load...\n---\n", file_format);
|
printf("\n---\nIdentified as LLAMA model: (ver %d)\nAttempting to Load...\n---\n", file_format);
|
||||||
return llama_load_model(inputs, file_format);
|
return llama_load_model(inputs, file_format);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_outputs generate(const generation_inputs inputs, generation_outputs &output)
|
generation_outputs generate(const generation_inputs inputs, generation_outputs &output)
|
||||||
{
|
{
|
||||||
if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2 || file_format==FileFormat::GPTJ_3
|
if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2 || file_format==FileFormat::GPTJ_3
|
||||||
|| file_format==FileFormat::GPT2_1 || file_format==FileFormat::GPT2_2 )
|
|| file_format==FileFormat::GPT2_1 || file_format==FileFormat::GPT2_2 )
|
||||||
{
|
{
|
||||||
return gpttype_generate(inputs, output);
|
return gpttype_generate(inputs, output);
|
||||||
|
@ -107,6 +109,6 @@ extern "C"
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
return llama_generate(inputs, output);
|
return llama_generate(inputs, output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue