From 1088d2dd04761968073afe17ba4824e6c1c94703 Mon Sep 17 00:00:00 2001
From: Thiago Padilha
Date: Sat, 18 Mar 2023 12:12:00 -0300
Subject: [PATCH] Move model loading back to main.cpp

Signed-off-by: Thiago Padilha
---
 llama.cpp | 59 ++++++-------------------------------------------------
 llama.h   |  9 ++++++++-
 main.cpp  | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 71 insertions(+), 56 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 21369cea2..35ec3e140 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -713,36 +713,12 @@ void sigint_handler(int signo) {
 }
 #endif
 
-const char * llama_print_system_info(void) {
-    static std::string s;
-
-    s = "";
-    s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
-    s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
-    s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
-    s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
-    s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
-    s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
-    s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
-    s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
-    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
-    s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
-    s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
-    s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
-
-    return s.c_str();
-}
-
-int llama_main(int argc, char ** argv) {
-    ggml_time_init();
-    const int64_t t_main_start_us = ggml_time_us();
-
-    gpt_params params;
-    params.model = "models/llama-7B/ggml-model.bin";
-
-    if (gpt_params_parse(argc, argv, params) == false) {
-        return 1;
-    }
+int llama_main(
+        gpt_params params,
+        gpt_vocab vocab,
+        llama_model model,
+        int64_t t_load_us,
+        int64_t t_main_start_us) {
 
     if (params.seed < 0) {
         params.seed = time(NULL);
@@ -758,29 +734,6 @@ int llama_main(int argc, char ** argv) {
     // params.prompt = R"(// this function checks if the number n is prime
 //bool is_prime(int n) {)";
 
-    int64_t t_load_us = 0;
-
-    gpt_vocab vocab;
-    llama_model model;
-
-    // load the model
-    {
-        const int64_t t_start_us = ggml_time_us();
-        if (!llama_model_load(params.model, model, vocab, params.n_ctx)) {
-            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-
-        t_load_us = ggml_time_us() - t_start_us;
-    }
-
-    // print system information
-    {
-        fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
-    }
-
     int n_past = 0;
 
     int64_t t_sample_us = 0;
diff --git a/llama.h b/llama.h
index ea71c7402..9cacb613c 100644
--- a/llama.h
+++ b/llama.h
@@ -6,6 +6,7 @@
 #include <vector>
 
 #include "ggml.h"
+#include "utils.h"
 
 // default hparams (LLaMA 7B)
 
@@ -58,4 +59,10 @@ struct llama_model {
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
-int llama_main(int argc, char ** argv);
+int llama_main(
+        gpt_params params,
+        gpt_vocab vocab,
+        llama_model model,
+        int64_t t_load_us,
+        int64_t t_main_start_us);
+bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx);
diff --git a/main.cpp b/main.cpp
index 8b9a3ff50..7106a8e19 100644
--- a/main.cpp
+++ b/main.cpp
@@ -1,5 +1,60 @@
+#include "ggml.h"
+#include "utils.h"
 #include "llama.h"
 
-int main(int argc, char ** argv) {
-    return llama_main(argc, argv);
+const char * llama_print_system_info(void) {
+    static std::string s;
+
+    s = "";
"AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + + return s.c_str(); +} + +int main(int argc, char ** argv) { + ggml_time_init(); + const int64_t t_main_start_us = ggml_time_us(); + + gpt_params params; + params.model = "models/llama-7B/ggml-model.bin"; + + if (gpt_params_parse(argc, argv, params) == false) { + return 1; + } + + int64_t t_load_us = 0; + + gpt_vocab vocab; + llama_model model; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + if (!llama_model_load(params.model, model, vocab, params.n_ctx)) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return 1; + } + + t_load_us = ggml_time_us() - t_start_us; + } + + // print system information + { + fprintf(stderr, "\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); + } + + return llama_main(params, vocab, model, t_main_start_us, t_load_us); }