Update llama_model_load() from master branch

Thomas Antony 2023-03-19 16:59:17 -07:00
parent 5195fed013
commit 1c545e51ed


@@ -128,7 +128,8 @@ struct llama_context
 /* Original code by @ggerganov */
 // load the model's weights from a file
-bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
+bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
     std::vector<char> f_buf(1024*1024);
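
For context, a minimal sketch of how the new memory_type parameter is typically applied further down in llama_model_load (not shown in this hunk): the KV cache tensors are allocated with the requested element type, so F16 roughly halves the memory used by the context. The names n_embd, n_layer, n_ctx, and ctx are assumed to be the hyperparameters and ggml context this function already sets up.

    // Sketch (assumed, not visible in this hunk): allocate the KV cache
    // using the element type selected by the new parameter.
    const int n_mem      = n_layer*n_ctx;
    const int n_elements = n_embd*n_mem;

    model.memory_k = ggml_new_tensor_1d(ctx, memory_type, n_elements);
    model.memory_v = ggml_new_tensor_1d(ctx, memory_type, n_elements);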
@@ -1071,9 +1072,12 @@ llama_context* llama_init_from_params(const gpt_params& params) {
     llama_model model{};
     gpt_vocab vocab{};
+    const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
     // Compute time taken to load model
     const int64_t t_start = ggml_time_us();
-    bool ret = llama_model_load(params.model, model, vocab, 1024);
+    bool ret = llama_model_load(params.model, model, vocab, params.n_ctx, memory_type);
     const int64_t t_end = ggml_time_us();
     if(!ret)
     {
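
The t_start/t_end pair measures wall-clock load time in microseconds. A hedged sketch of how that measurement would be reported, assuming the fprintf-to-stderr style used elsewhere in this file (the exact message is not part of this hunk):

    // Sketch: report the load time measured above. ggml_time_us()
    // returns microseconds, so divide by 1000.0 for milliseconds.
    fprintf(stderr, "%s: model loaded in %8.2f ms\n", __func__, (t_end - t_start)/1000.0);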