wip pythia integration

Concedo 2023-04-22 01:08:23 +08:00
parent 68898046c2
commit ef13443047
4 changed files with 34 additions and 3 deletions

View file

@@ -115,6 +115,19 @@ extern "C"
             return true;
         }
     }
+    else if(file_format==FileFormat::NEOX_1)
+    {
+        printf("\n---\nIdentified as GPT-NEO-X model: (ver %d)\nAttempting to Load...\n---\n", file_format);
+        ModelLoadResult lr = gpttype_load_model(inputs, file_format);
+        if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD)
+        {
+            return false;
+        }
+        else
+        {
+            return true;
+        }
+    }
     else
     {
         printf("\n---\nIdentified as LLAMA model: (ver %d)\nAttempting to Load...\n---\n", file_format);

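In the branch above, gpttype_load_model's three-state ModelLoadResult is collapsed into the plain bool that this extern "C" entry point returns, with RETRY_LOAD treated the same as FAIL. A minimal sketch of that collapsing pattern (the wrapper name below is hypothetical; only gpttype_load_model, FileFormat, and ModelLoadResult come from the diff):

// Sketch only: mapping the three-state ModelLoadResult onto the C ABI's bool.
// try_load_as is a hypothetical stand-in for the dispatch shown above.
static bool try_load_as(const load_model_inputs &inputs, FileFormat file_format)
{
    ModelLoadResult lr = gpttype_load_model(inputs, file_format);
    // RETRY_LOAD is folded into failure, exactly as in the NEOX_1 branch
    // above: a boolean interface has no way to request a retry.
    return lr != ModelLoadResult::FAIL && lr != ModelLoadResult::RETRY_LOAD;
}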
View file

@@ -201,6 +201,18 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         return ModelLoadResult::SUCCESS;
     }
+    else if(file_format==FileFormat::NEOX_1)
+    {
+        bool res = stablelm_model_load(params.model, neox_ctx, vocab);
+        if(!res)
+        {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return ModelLoadResult::FAIL;
+        }
+        // determine the required inference memory per token:
+        stablelm_eval(neox_ctx, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+        return ModelLoadResult::SUCCESS;
+    }
     else
     {
         ModelLoadResult loadresult = gptj_model_load(params.model, gptj_ctx_v2, vocab);

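The commented warm-up call above follows the convention of the ggml example loaders: right after the weights load, evaluate one small dummy batch so the measured per-token memory requirement is known before any real prompt arrives. A sketch of that idea using the same call shape (the helper below is hypothetical; stablelm_eval, neox_ctx, and mem_per_token are the names used in the hunk):

// Sketch: measure per-token inference memory with one throwaway evaluation.
static size_t warm_up_mem_per_token(int n_threads)
{
    size_t mem_per_token = 0;
    std::vector<float> logits; // scratch output, discarded after the warm-up
    // n_past = 0 and an arbitrary 4-token batch; the eval writes the
    // per-token memory requirement into mem_per_token as a side effect.
    stablelm_eval(neox_ctx, n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
    return mem_per_token;
}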
View file

@@ -113,7 +113,7 @@ void print_tok_vec(std::vector<float> &embd)
             fileformat = FileFormat::GPTJ_3; //quantized format cannot be legacy type
         }
     }
-    if(vocabsiz==50257)
+    else if(vocabsiz==50257)
     {
         fileformat = FileFormat::GPT2_1;
         uint32_t temp;
@@ -125,8 +125,12 @@ void print_tok_vec(std::vector<float> &embd)
             if(temp!=0 && temp!=1)
             {
                 fileformat = FileFormat::GPT2_2; //quantized format cannot be legacy type
             }
         }
     }
+    else if(vocabsiz < 32000 || vocabsiz > 36000)
+    {
+        //anything outside the llama v1 range is assumed to be NeoX
+        fileformat = FileFormat::NEOX_1;
+    }
 }
 else if(magic == 0x67676d66) //v2 format ggmf

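Two things happen in this file: the stray if becomes else if, and a catch-all branch is added, so the vocab-size checks now form one exclusive chain. The ordering matters: GPT-J's vocabulary size (50400) and GPT-2's (50257) both lie outside the 32000-36000 llama v1 window, so without the earlier branches claiming those files first, both would fall through into the new NeoX catch-all. A self-contained restatement of the heuristic (the enum and function names are illustrative, not from the source):

#include <cstdint>

enum class FormatGuess { GPTJ, GPT2, LLAMA, NEOX };

// Mirrors the else-if chain above: evaluated top to bottom, first match wins.
FormatGuess guess_v1_format(uint32_t vocabsiz)
{
    if (vocabsiz == 50400)
        return FormatGuess::GPTJ;   // GPT-J branch (only partially shown above)
    if (vocabsiz == 50257)
        return FormatGuess::GPT2;   // GPT2_1, upgraded to GPT2_2 if quantized
    if (vocabsiz < 32000 || vocabsiz > 36000)
        return FormatGuess::NEOX;   // anything outside the llama v1 range
    return FormatGuess::LLAMA;      // within the llama v1 range
}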
View file

@@ -28,6 +28,8 @@ enum FileFormat
     GPT2_2=201,
     RWKV_1=300,
+    NEOX_1=400,
 };
 enum ModelLoadResult