Move model loading back to main.cpp

Signed-off-by: Thiago Padilha <thiago@padilha.cc>
Thiago Padilha 2023-03-18 12:12:00 -03:00
parent e3648474d6
commit 1088d2dd04
No known key found for this signature in database
GPG key ID: 309C78E5ED1B3D5E
3 changed files with 71 additions and 56 deletions

llama.cpp

@@ -713,36 +713,12 @@ void sigint_handler(int signo) {
 }
 #endif

-const char * llama_print_system_info(void) {
-    static std::string s;
-
-    s = "";
-    s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
-    s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
-    s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
-    s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
-    s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
-    s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
-    s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
-    s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
-    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
-    s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
-    s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
-    s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
-
-    return s.c_str();
-}
-
-int llama_main(int argc, char ** argv) {
-    ggml_time_init();
-    const int64_t t_main_start_us = ggml_time_us();
-
-    gpt_params params;
-    params.model = "models/llama-7B/ggml-model.bin";
-
-    if (gpt_params_parse(argc, argv, params) == false) {
-        return 1;
-    }
+int llama_main(
+    gpt_params params,
+    gpt_vocab vocab,
+    llama_model model,
+    int64_t t_load_us,
+    int64_t t_main_start_us) {

     if (params.seed < 0) {
         params.seed = time(NULL);
@@ -758,29 +734,6 @@ int llama_main(int argc, char ** argv) {
     // params.prompt = R"(// this function checks if the number n is prime
     //bool is_prime(int n) {)";

-    int64_t t_load_us = 0;
-
-    gpt_vocab vocab;
-    llama_model model;
-
-    // load the model
-    {
-        const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
-            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-
-        t_load_us = ggml_time_us() - t_start_us;
-    }
-
-    // print system information
-    {
-        fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
-    }
-
     int n_past = 0;

     int64_t t_sample_us = 0;
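
To make the hunks above easier to scan, here is the new entry point restated with editorial annotations (the comments are mine, not part of the diff): after this change the caller owns argument parsing, vocab/model loading and timing, and llama_main keeps only the inference side.

    int llama_main(
        gpt_params  params,            // parsed options (model path, seed, n_ctx, ...)
        gpt_vocab   vocab,             // vocabulary loaded by the caller
        llama_model model,             // weights loaded by the caller via llama_model_load()
        int64_t     t_load_us,         // how long the caller spent loading the model
        int64_t     t_main_start_us);  // when the calling program started (from ggml_time_us())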

llama.h

@@ -6,6 +6,7 @@
 #include <string>

 #include "ggml.h"
+#include "utils.h"

 // default hparams (LLaMA 7B)

@@ -58,4 +59,10 @@ struct llama_model {
     std::map<std::string, struct ggml_tensor *> tensors;
 };

-int llama_main(int argc, char ** argv);
+int llama_main(
+    gpt_params params,
+    gpt_vocab vocab,
+    llama_model model,
+    int64_t t_load_us,
+    int64_t t_main_start_us);
+bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx);

main.cpp

@@ -1,5 +1,60 @@
+#include "ggml.h"
+#include "utils.h"
 #include "llama.h"

-int main(int argc, char ** argv) {
-    return llama_main(argc, argv);
+const char * llama_print_system_info(void) {
+    static std::string s;
+
+    s = "";
+    s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
+    s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
+    s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
+    s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
+    s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
+    s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
+    s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
+    s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
+    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
+    s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
+    s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
+    s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
+
+    return s.c_str();
+}
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+    const int64_t t_main_start_us = ggml_time_us();
+
+    gpt_params params;
+    params.model = "models/llama-7B/ggml-model.bin";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    int64_t t_load_us = 0;
+
+    gpt_vocab vocab;
+    llama_model model;
+
+    // load the model
+    {
+        const int64_t t_start_us = ggml_time_us();
+        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+
+        t_load_us = ggml_time_us() - t_start_us;
+    }
+
+    // print system information
+    {
+        fprintf(stderr, "\n");
+        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
+                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
+    }
+
+    return llama_main(params, vocab, model, t_load_us, t_main_start_us);
 }
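
Beyond the new main.cpp above, the refactored llama.h surface also lets a host program set parameters programmatically instead of going through gpt_params_parse. A minimal sketch of such an embedder, assuming the declarations from llama.h plus the gpt_params/gpt_vocab types from utils.h; the file name embedder.cpp and the prompt string are hypothetical:

    // embedder.cpp (hypothetical): load the model once in the host program,
    // then hand the loaded state to llama_main(), as main.cpp does after this commit.
    #include <cstdio>

    #include "ggml.h"
    #include "utils.h"
    #include "llama.h"

    int main() {
        ggml_time_init();
        const int64_t t_main_start_us = ggml_time_us();

        gpt_params params;               // defaults from utils.h, no argv parsing
        params.model  = "models/llama-7B/ggml-model.bin";
        params.prompt = "Building a website can be done in 10 simple steps:";

        gpt_vocab vocab;
        llama_model model;

        const int64_t t_start_us = ggml_time_us();
        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
            fprintf(stderr, "failed to load model from '%s'\n", params.model.c_str());
            return 1;
        }
        const int64_t t_load_us = ggml_time_us() - t_start_us;

        // vocab, model and the timings are owned by the caller and passed in
        return llama_main(params, vocab, model, t_load_us, t_main_start_us);
    }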