wip pythia integration
This commit is contained in:
parent
68898046c2
commit
ef13443047
4 changed files with 34 additions and 3 deletions
13
expose.cpp
13
expose.cpp
|
@ -115,6 +115,19 @@ extern "C"
|
|||
return true;
|
||||
}
|
||||
}
|
||||
else if(file_format==FileFormat::NEOX_1)
|
||||
{
|
||||
printf("\n---\nIdentified as GPT-NEO-X model: (ver %d)\nAttempting to Load...\n---\n", file_format);
|
||||
ModelLoadResult lr = gpttype_load_model(inputs, file_format);
|
||||
if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("\n---\nIdentified as LLAMA model: (ver %d)\nAttempting to Load...\n---\n", file_format);
|
||||
|
|
|
@ -201,6 +201,18 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
|
||||
return ModelLoadResult::SUCCESS;
|
||||
}
|
||||
else if(file_format==FileFormat::NEOX_1)
|
||||
{
|
||||
bool res = stablelm_model_load(params.model, neox_ctx, vocab);
|
||||
if(!res)
|
||||
{
|
||||
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
|
||||
return ModelLoadResult::FAIL;
|
||||
}
|
||||
// determine the required inference memory per token:
|
||||
stablelm_eval(neox_ctx, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
||||
return ModelLoadResult::SUCCESS;
|
||||
}
|
||||
else
|
||||
{
|
||||
ModelLoadResult loadresult = gptj_model_load(params.model, gptj_ctx_v2, vocab);
|
||||
|
|
|
@ -113,7 +113,7 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
fileformat = FileFormat::GPTJ_3; //quantized format cannot be legacy type
|
||||
}
|
||||
}
|
||||
if(vocabsiz==50257)
|
||||
else if(vocabsiz==50257)
|
||||
{
|
||||
fileformat = FileFormat::GPT2_1;
|
||||
uint32_t temp;
|
||||
|
@ -125,8 +125,12 @@ void print_tok_vec(std::vector<float> &embd)
|
|||
if(temp!=0 && temp!=1)
|
||||
{
|
||||
fileformat = FileFormat::GPT2_2; //quantized format cannot be legacy type
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
else if(vocabsiz < 32000 || vocabsiz > 36000)
|
||||
{
|
||||
//any vocab size outside the llama v1 range (32000 to 36000) is assumed to be NeoX
|
||||
fileformat = FileFormat::NEOX_1;
|
||||
}
|
||||
}
|
||||
else if(magic == 0x67676d66) //v2 format ggmf
|
||||
|
|
|
@ -28,6 +28,8 @@ enum FileFormat
|
|||
GPT2_2=201,
|
||||
|
||||
RWKV_1=300,
|
||||
|
||||
NEOX_1=400,
|
||||
};
|
||||
|
||||
enum ModelLoadResult
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue