diff --git a/expose.cpp b/expose.cpp
index 436bf2e54..ab4998088 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -115,6 +115,19 @@ extern "C"
                 return true;
             }
         }
+        else if(file_format==FileFormat::NEOX_1)
+        {
+            printf("\n---\nIdentified as GPT-NEO-X model: (ver %d)\nAttempting to Load...\n---\n", file_format);
+            ModelLoadResult lr = gpttype_load_model(inputs, file_format);
+            if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD)
+            {
+                return false;
+            }
+            else
+            {
+                return true;
+            }
+        }
         else
         {
             printf("\n---\nIdentified as LLAMA model: (ver %d)\nAttempting to Load...\n---\n", file_format);
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index fa41f8ccb..9e73ac9eb 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -201,6 +201,18 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 
         return ModelLoadResult::SUCCESS;
     }
+    else if(file_format==FileFormat::NEOX_1)
+    {
+        bool res = stablelm_model_load(params.model, neox_ctx, vocab);
+        if(!res)
+        {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return ModelLoadResult::FAIL;
+        }
+        // determine the required inference memory per token:
+        stablelm_eval(neox_ctx, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+        return ModelLoadResult::SUCCESS;
+    }
     else
     {
         ModelLoadResult loadresult = gptj_model_load(params.model, gptj_ctx_v2, vocab);
diff --git a/model_adapter.cpp b/model_adapter.cpp
index a8893f06c..66bbda14f 100644
--- a/model_adapter.cpp
+++ b/model_adapter.cpp
@@ -113,7 +113,7 @@ void print_tok_vec(std::vector &embd)
                 fileformat = FileFormat::GPTJ_3; //quantized format cannot be legacy type
             }
         }
-        if(vocabsiz==50257)
+        else if(vocabsiz==50257)
         {
             fileformat = FileFormat::GPT2_1;
             uint32_t temp;
@@ -125,8 +125,12 @@ void print_tok_vec(std::vector &embd)
             if(temp!=0 && temp!=1)
             {
                 fileformat = FileFormat::GPT2_2; //quantized format cannot be legacy type
-            }
-
+            }
+        }
+        else if(vocabsiz < 32000 || vocabsiz > 36000)
+        {
+            //anything outside the llama v1 range is assumed to be NeoX
+            fileformat = FileFormat::NEOX_1;
         }
     }
     else if(magic == 0x67676d66) //v2 format ggmf
diff --git a/model_adapter.h b/model_adapter.h
index 040196063..344643d2b 100644
--- a/model_adapter.h
+++ b/model_adapter.h
@@ -28,6 +28,8 @@ enum FileFormat
     GPT2_2=201,
 
     RWKV_1=300,
+
+    NEOX_1=400,
 };
 
 enum ModelLoadResult