From 8fbfc80e037e8e95854848884f6857c29485746b Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sat, 15 Apr 2023 12:02:36 +0200 Subject: [PATCH] Fix clblast device selection on Linux --- expose.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/expose.cpp b/expose.cpp index 893f427a0..d8b62b3e1 100644 --- a/expose.cpp +++ b/expose.cpp @@ -23,6 +23,8 @@ extern "C" { + std::string platformenv, deviceenv; + //return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt) static FileFormat file_format = FileFormat::BADFORMAT; @@ -33,15 +35,15 @@ extern "C" //first digit is whether configured, second is platform, third is devices int parseinfo = inputs.clblast_info; - + std::string usingclblast = "KCPP_CLBLAST_CONFIGURED="+std::to_string(parseinfo>0?1:0); putenv((char*)usingclblast.c_str()); - parseinfo = parseinfo%100; //keep last 2 digits + parseinfo = parseinfo%100; //keep last 2 digits int platform = parseinfo/10; int devices = parseinfo%10; - std::string platformenv = "KCPP_CLBLAST_PLATFORM="+std::to_string(platform); - std::string deviceenv = "KCPP_CLBLAST_DEVICES="+std::to_string(devices); + platformenv = "KCPP_CLBLAST_PLATFORM="+std::to_string(platform); + deviceenv = "KCPP_CLBLAST_DEVICES="+std::to_string(devices); putenv((char*)platformenv.c_str()); putenv((char*)deviceenv.c_str()); @@ -61,7 +63,7 @@ extern "C" printf("\n---\nRetrying as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format); lr = gpttype_load_model(inputs, file_format); } - + if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD) { return false; @@ -92,14 +94,14 @@ extern "C" } else { - printf("\n---\nIdentified as LLAMA model: (ver %d)\nAttempting to Load...\n---\n", file_format); + printf("\n---\nIdentified as LLAMA model: (ver %d)\nAttempting to Load...\n---\n", file_format); return llama_load_model(inputs, file_format); } } generation_outputs generate(const generation_inputs inputs, generation_outputs &output) { - if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2 || file_format==FileFormat::GPTJ_3 + if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2 || file_format==FileFormat::GPTJ_3 || file_format==FileFormat::GPT2_1 || file_format==FileFormat::GPT2_2 ) { return gpttype_generate(inputs, output); @@ -107,6 +109,6 @@ extern "C" else { return llama_generate(inputs, output); - } + } } }